from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub import login
import os
import torch
import logging

# Read the Hugging Face token from the environment variable
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token is None:
    raise ValueError(
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )

# Log in with the token
login(token=token)

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load a smaller model (GPT-Neo-125M) and tokenizer
model_id = "EleutherAI/gpt-neo-125M"  # Use GPT-Neo-125M for faster performance
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set pad_token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token

# Load the model without quantization for CPU
logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # Use FP32 for CPU compatibility
    device_map="auto",  # Automatically offload to available devices
)

# Create a text generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50  # Reduce this for faster responses
    temperature: float = 0.7  # Lower for more deterministic outputs
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True

# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text...")
        # Generate text using the pipeline with the user's prompt
        outputs = pipe(
            request.prompt,  # Use the user's prompt directly
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            do_sample=request.do_sample,
            return_full_text=False,  # Exclude the input prompt from the output
        )
        return {"generated_text": outputs[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Add a root endpoint for health checks
@app.get("/")
async def root():
    return {"message": "Welcome to the Text Generation API! Use /generate-text to generate text."}

# Add a test endpoint
@app.get("/test")
async def test():
    return {"message": "API is running!"}
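
# --- Usage sketch (not part of the original app) ---
# A minimal way to serve and exercise this API locally. Assumptions: uvicorn is
# installed, this file is named app.py (so the module name "app" below is a
# guess), and port 7860 is used because it is the default port expected by
# Hugging Face Spaces.
if __name__ == "__main__":
    import uvicorn

    # Start the ASGI server; adjust host/port for your environment
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Once the server is up, the endpoint can be tested with curl, e.g.:
#   curl -X POST http://localhost:7860/generate-text \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_new_tokens": 50}'
# which should return a JSON body of the form {"generated_text": "..."}.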