from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the Google Gemma 7B model and tokenizer
model_id = "google/gemma-7b"  # Use Google Gemma 7B
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model with 4-bit quantization to reduce VRAM usage
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # Half-precision compute for faster inference
    device_map="auto",          # Automatically place layers on available GPUs
    load_in_4bit=True           # Enable 4-bit quantization (requires bitsandbytes)
)

# Create a text generation pipeline.
# Note: no `device=` argument here; the model was already placed by
# `device_map="auto"`, and passing an explicit device as well would conflict.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50   # Reduce this for faster responses
    temperature: float = 0.7   # Lower for more deterministic outputs
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True

# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text...")
        # Generate text using the pipeline with the user's prompt
        outputs = pipe(
            request.prompt,  # Use the user's prompt directly
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            do_sample=request.do_sample,
            return_full_text=False  # Exclude the input prompt from the output
        )
        return {"generated_text": outputs[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Add a root endpoint for health checks
@app.get("/test")
async def root():
    return {"message": "API is running!"}
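
# --- Example: running and calling the API (a minimal sketch) ---
# The module name, host, and port below are assumptions, not part of the
# listing above; adjust them to your setup. You can either run this file
# directly via the block below, or start it with uvicorn from the shell:
#   uvicorn main:app --host 0.0.0.0 --port 8000   # assumes this file is main.py
#
# Once running, the endpoints defined above can be exercised with curl:
#   curl http://localhost:8000/test
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain 4-bit quantization in one sentence.", "max_new_tokens": 64}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)  # assumed host/port for local testing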