from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# Initialize the FastAPI app
app = FastAPI()

# Load the Falcon-7B-Instruct model, with 8-bit quantization when CUDA is available
model_id = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

if torch.cuda.is_available():
    # Load the model in 8-bit on the GPU (requires the bitsandbytes package)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,
        device_map="auto",
        trust_remote_code=True
    )
else:
    # Fall back to full precision on CPU
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True
    )

# Create a text-generation pipeline around the loaded model
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Request body schema for the generation endpoint
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True

# Text-generation endpoint. Declared as a plain `def` so FastAPI runs the
# blocking model call in a worker thread rather than blocking the event loop.
@app.post("/generate-text")
def generate_text(request: TextGenerationRequest):
    try:
        # Generate text using the pipeline
        outputs = pipe(
            request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            do_sample=request.do_sample
        )
        return {"generated_text": outputs[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Simple health-check endpoint
@app.get("/test")
async def health_check():
    return {"message": "API is running!"}
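
# --- Usage sketch ---
# A minimal way to run and exercise this service, assuming the file is saved
# as main.py (the module name "main" is an assumption, not from the source):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# The generation endpoint can then be queried with curl; the path and JSON
# fields below come directly from the TextGenerationRequest schema above:
#
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain 8-bit quantization in one sentence.", "max_new_tokens": 60}'
#
# And the health check:
#
#   curl http://localhost:8000/test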