# Author: MilindChawre
# Commit b7ca7fe: "Updating README and splitting training logic"
import streamlit as st
import torch
import tiktoken
from transformer import GPT, GPTConfig # Ensure you import your model class
# Load the trained model
@st.cache_resource
def load_model():
    """Build the GPT model and load its trained weights for CPU inference.

    Decorated with ``st.cache_resource`` so the model is constructed and the
    checkpoint read only once per server process, not on every rerun.

    Returns:
        The ``GPT`` instance, switched to evaluation mode, with the weights
        from ``trained_model_quantized.pt`` applied.

    If the checkpoint cannot be read or applied, an error is displayed and the
    current script run is halted with ``st.stop()``. Halting (instead of
    returning the freshly initialised model) keeps the app from producing text
    with untrained weights. Because the cached function raises rather than
    returns, the failure is not cached, and the next run retries the load.
    """
    config = GPTConfig()
    model = GPT(config)
    try:
        # map_location remaps any GPU-saved tensors onto the CPU so the
        # checkpoint also loads on CPU-only hosts.
        state_dict = torch.load('trained_model_quantized.pt', map_location=torch.device('cpu'))
        # strict=False tolerates key mismatches between checkpoint and model;
        # the returned record lists them so they can be surfaced below.
        incompatible = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()  # raises; nothing below runs and nothing is cached
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # Missing parameters silently keep their random initial values with
        # strict=False, so make the mismatch visible instead of hiding it.
        st.warning(
            f"Checkpoint/model mismatch: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected keys were ignored."
        )
    model.eval()  # disable training-only behaviour (e.g. dropout) for inference
    st.success("Model loaded successfully!")
    return model
# Tokenizer factory
def load_tokenizer():
    """Return tiktoken's GPT-2 byte-pair encoding used to encode/decode text."""
    encoding_name = 'gpt2'
    encoding = tiktoken.get_encoding(encoding_name)
    return encoding
# Generate text function
def generate_text(model, tokenizer, input_text, length, num_sequences):
    """Sample ``num_sequences`` independent continuations of ``input_text``.

    Every sequence starts from the same prompt. All sequences are sampled
    together as one batch, so each generation step is a single forward pass.

    Args:
        model: Callable taking a ``(batch, seq)`` long tensor of token ids.
            Assumed to return a sequence whose first element is the logits
            with shape ``(batch, seq, vocab)``. NOTE(review): matches the
            ``model(...)[0]`` usage; confirm against ``GPT.forward``.
        tokenizer: Object with ``encode(str) -> list[int]``,
            ``decode(list[int]) -> str`` and, for empty prompts, an
            ``eot_token`` id (tiktoken ``Encoding`` provides all three).
        input_text: Prompt to continue.
        length: Number of new tokens to sample per sequence.
        num_sequences: Number of continuations to return.

    Returns:
        A list of ``num_sequences`` decoded strings, each being the prompt
        followed by its own sampled continuation. If ``num_sequences`` is 0 or
        negative, the list is empty.
    """
    if num_sequences <= 0:
        return []

    prompt_ids = tokenizer.encode(input_text)
    # An empty prompt would give the model a zero-length input. Condition on
    # the end-of-text token instead, and strip it again when decoding.
    seed_ids = prompt_ids or [tokenizer.eot_token]
    skip = len(seed_ids) - len(prompt_ids)

    # One row per requested sequence, each starting from the prompt.
    # dtype is explicit so token ids are always integer indices.
    tokens = torch.tensor(seed_ids, dtype=torch.long).unsqueeze(0).repeat(num_sequences, 1)

    with torch.no_grad():
        for _ in range(length):
            logits = model(tokens)[0]
            next_token_logits = logits[:, -1, :]  # (num_sequences, vocab)
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            # multinomial samples each row independently -> (num_sequences, 1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)

    return [tokenizer.decode(row[skip:].tolist()) for row in tokens]
# --- Streamlit page: prompt inputs, then generation on button press ---
st.title("GPT Text Generator")
st.write("Enter your text and specify the length of additional text to generate.")

# Prompt box is capped at 512 characters; the sliders take (label, min, max, default).
input_text = st.text_area("Input Text", "Once upon a time", max_chars=512)
length = st.slider("Predict Additional Text of Length", 1, 50, 10)
num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)

if st.button("Generate"):
    # Model and tokenizer are fetched only once the user asks for output.
    model, tokenizer = load_model(), load_tokenizer()

    st.write("Generating text...")
    generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
    st.write("Text generation complete.")

    st.write("Generated Texts:")
    for seq_no, text in enumerate(generated_texts, start=1):
        st.subheader(f"Sequence {seq_no}")
        st.write(text)