"""FastAPI service that forwards text-generation requests to the Hugging Face Inference API."""

import logging
import os
from typing import Any

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# Read the Hugging Face token from the environment variable
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token is None:
    raise ValueError(
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Inference API endpoint for DeepSeek
API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"  # Replace with the correct model ID
headers = {"Authorization": f"Bearer {token}"}

# Upper bound (seconds) on a single upstream call. Without a timeout,
# `requests` waits indefinitely when the upstream stalls.
REQUEST_TIMEOUT_SECONDS = 120


# Define request body schema
class TextGenerationRequest(BaseModel):
    """Request body for ``POST /generate-text``.

    Every field except ``prompt`` is forwarded unchanged in the ``parameters``
    object of the upstream payload.
    """

    prompt: str
    max_new_tokens: int = 150  # Increased to allow longer responses
    temperature: float = 0.8  # Slightly increased for more varied responses
    top_k: int = 100  # Broader vocabulary sampling
    top_p: float = 0.92  # Increased diversity
    repetition_penalty: float = 1.2  # Penalizes repetition
    do_sample: bool = True


# Function to remove repetitive content in the output
def remove_repetition(text: str) -> str:
    """Return *text* keeping only the first occurrence of each word.

    Words are the whitespace-separated tokens of *text*; comparison is
    case-insensitive and the first-seen spelling is kept. The result is
    re-joined with single spaces, so original line breaks and repeated
    spaces are not preserved.

    NOTE(review): this drops *every* later occurrence of a word, not only
    adjacent repeats ("the cat and the dog" -> "the cat and dog"). Confirm
    that this is the intended clean-up before relying on the output.
    """
    seen: set[str] = set()
    result: list[str] = []
    for word in text.split():
        key = word.lower()
        if key not in seen:
            seen.add(key)
            result.append(word)
    return " ".join(result)


def _extract_generated_text(body: Any) -> str:
    """Return ``body[0]["generated_text"]`` or raise a descriptive HTTP 500.

    The endpoint expects a JSON list whose first element is an object with a
    string ``generated_text`` field; any other shape is reported with a clear
    message instead of a bare ``KeyError``/``IndexError`` text.
    """
    try:
        generated_text = body[0]["generated_text"]
    except (IndexError, KeyError, TypeError) as exc:
        logger.error("Unexpected response format from upstream: %r", body)
        raise HTTPException(
            status_code=500,
            detail="Unexpected response format from the Hugging Face Inference API",
        ) from exc
    if not isinstance(generated_text, str):
        logger.error("Unexpected response format from upstream: %r", body)
        raise HTTPException(
            status_code=500,
            detail="Unexpected response format from the Hugging Face Inference API",
        )
    return generated_text


# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text for ``request.prompt`` via the Hugging Face Inference API.

    Returns:
        ``{"generated_text": <str>}`` where the upstream text has been passed
        through :func:`remove_repetition`.

    Raises:
        HTTPException: with the upstream status code and body when the
            upstream call does not return 200; with status 500 for any other
            failure (network error, timeout, malformed upstream response).
    """
    logger.info("Generating text...")

    # Prepare the payload for the Hugging Face Inference API
    payload = {
        "inputs": request.prompt,
        "parameters": {
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "repetition_penalty": request.repetition_penalty,
            "do_sample": request.do_sample,
        },
    }

    try:
        # `requests` is blocking; run it in the thread pool so the event loop
        # keeps serving other requests while the upstream call is in flight.
        response = await run_in_threadpool(
            requests.post,
            API_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )

        # Check for errors in the response
        if response.status_code != 200:
            logger.error("API Error: %s - %s", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)

        # Extract and process the generated text
        generated_text = _extract_generated_text(response.json())
        cleaned_text = remove_repetition(generated_text)  # Clean up repetition
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged;
        # the generic handler below would otherwise rewrite them as 500.
        raise
    except Exception as e:
        logger.exception("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"generated_text": cleaned_text}


# Add a root endpoint for health checks
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Text Generation API! Use /generate-text to generate text."}


# Add a test endpoint
@app.get("/test")
async def test():
    """Return a static message confirming the API is running."""
    return {"message": "API is running!"}