Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import tiktoken | |
from transformer import GPT, GPTConfig # Ensure you import your model class | |
# Load the trained model
def load_model():
    """Build a GPT from the default GPTConfig and load the saved weights.

    Weights are read from ``trained_model_quantized.pt`` onto the CPU.
    Loading is non-strict, so any checkpoint keys that do not match the
    model are reported in the UI instead of silently leaving those
    parameters at the values they had after construction. A load error is
    shown via ``st.error`` and the unloaded model is still returned so the
    app keeps running.

    Returns:
        The GPT instance, always switched to eval mode.
    """
    config = GPTConfig()
    model = GPT(config)
    try:
        # map_location lets a checkpoint saved on GPU load in a CPU-only
        # environment. NOTE: torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        state_dict = torch.load('trained_model_quantized.pt', map_location=torch.device('cpu'))
        # With strict=False, mismatched keys do not raise; they are returned.
        load_result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        st.error(f"Error loading model: {e}")
    else:
        missing = load_result.missing_keys
        unexpected = load_result.unexpected_keys
        if missing or unexpected:
            st.warning(
                "Model weights loaded with mismatches: "
                f"{len(missing)} missing key(s) (not loaded from the checkpoint), "
                f"{len(unexpected)} unexpected key(s) (ignored)."
            )
        else:
            st.success("Model loaded successfully!")
    # Switch to inference mode regardless of the load outcome, since the
    # caller uses the returned model for generation either way.
    model.eval()
    return model
def load_tokenizer():
    """Return tiktoken's ``gpt2`` encoding, used to encode prompts and decode output."""
    gpt2_encoding = tiktoken.get_encoding('gpt2')
    return gpt2_encoding
# Generate text function
def generate_text(model, tokenizer, input_text, length, num_sequences, block_size=None):
    """Sample ``num_sequences`` independent continuations of ``input_text``.

    Each sequence restarts from the encoded prompt. It is then extended one
    token at a time by sampling from the softmax of the model's logits at
    the last position (plain multinomial sampling; no temperature or top-k).

    Args:
        model: Callable whose result's first element is the logits tensor,
            indexed here as ``[batch, position, vocab]``.
        tokenizer: Object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str`` (e.g. a tiktoken encoding).
        input_text: Prompt to continue; must encode to at least one token.
        length: Number of new tokens to sample per sequence.
        num_sequences: Number of independent sequences to generate.
        block_size: Maximum number of most-recent tokens fed to the model per
            step. When ``None``, it is read from ``model.config.block_size``
            if that attribute exists; otherwise the full history is fed.

    Returns:
        A list of ``num_sequences`` decoded strings, each holding the prompt
        followed by its sampled continuation.

    Raises:
        ValueError: If ``input_text`` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(input_text)
    if not prompt_ids:
        # An empty context would give the model a [1, 0] input with no last
        # position to read logits from.
        raise ValueError("input_text must encode to at least one token")
    # dtype=torch.long keeps ids usable as embedding indices and matches the
    # dtype torch.multinomial returns, so torch.cat below stays well-typed.
    prompt = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)  # [1, T]

    if block_size is None:
        # Assumption: GPT models expose their context length as config.block_size;
        # fall back to no cropping when they don't.
        block_size = getattr(getattr(model, "config", None), "block_size", None)

    generated_sequences = []
    with torch.no_grad():
        for _ in range(num_sequences):
            # Restart from the prompt so each sequence is an independent
            # sample rather than a continuation of the previous one.
            tokens = prompt
            for _ in range(length):
                # Feed only the most recent block_size tokens so long prompts
                # do not overrun the model's context window.
                context = tokens if block_size is None else tokens[:, -block_size:]
                logits = model(context)[0]
                next_token_logits = logits[:, -1, :]  # last position's logits
                next_token_probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                tokens = torch.cat((tokens, next_token.view(1, -1)), dim=1)
            # Decode the full history (prompt + continuation), not just the window.
            generated_sequences.append(tokenizer.decode(tokens[0].tolist()))
    return generated_sequences
# Streamlit app layout
st.title("GPT Text Generator")
st.write("Enter your text and specify the length of additional text to generate.")
input_text = st.text_area("Input Text", "Once upon a time", max_chars=512)  # Limit to 512 characters
length = st.slider("Predict Additional Text of Length", 1, 50, 10)
num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)

if st.button("Generate"):
    if not input_text:
        # Generation needs at least one prompt token; an empty text area would
        # otherwise crash the script inside generate_text.
        st.warning("Please enter some input text to continue.")
    else:
        # NOTE(review): the model is rebuilt and reloaded on every click;
        # consider caching it (e.g. st.cache_resource) if load time matters.
        model = load_model()  # Load the model for inference
        tokenizer = load_tokenizer()  # Load the tokenizer
        st.write("Generating text...")
        generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
        st.write("Text generation complete.")
        st.write("Generated Texts:")
        for seq_number, text in enumerate(generated_texts, start=1):
            st.subheader(f"Sequence {seq_number}")
            st.write(text)