from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the GPT-2 model and tokenizer
model_id = "gpt2"  # Use GPT-2
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# GPT-2 has no pad token; reuse EOS to avoid generation warnings
tokenizer.pad_token = tokenizer.eos_token

# Create a text generation pipeline, placing the model on GPU if one is available
device = 0 if torch.cuda.is_available() else -1
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

# Define the pre-prompt. GPT-2 is a base model, not instruction-tuned, so this
# acts as plain conditioning text rather than a true system message.
PRE_PROMPT = "You are a helpful virtual assistant. Answer the user's question clearly and concisely."

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50  # Reduce this for faster responses
    temperature: float = 0.7  # Lower for more deterministic outputs
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True

# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text...")
        # Combine the pre-prompt and user's prompt
        combined_input = f"{PRE_PROMPT} {request.prompt}"
        # Generate text using the pipeline
        outputs = pipe(
            combined_input,  # Use the combined input
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            do_sample=request.do_sample,
            return_full_text=False,  # Exclude the input prompt from the output
        )
        return {"generated_text": outputs[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Health-check endpoint
@app.get("/test")
async def root():
    return {"message": "API is running!"}
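
# --- Usage sketch ---
# Assumptions (not fixed by the code above): the file is saved as main.py and
# the server listens on localhost:8000; adjust names and ports as needed.
#
# Start the server:
#   uvicorn main:app --reload
#
# Example generation request:
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is FastAPI?", "max_new_tokens": 60}'
#
# Health check:
#   curl http://localhost:8000/test

# Optionally, run the app directly with `python main.py` instead of invoking
# uvicorn from the command line
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)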