# Source: Hugging Face Space "808-GPT2", file app.py
# Author: Fred808 — commit 28bfc95 ("Update app.py", verified), 3.16 kB
import asyncio
import functools
import logging
import os
import re

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# Read the Hugging Face token from the environment variable and fail fast at
# import time if it is missing, so a misconfigured deployment never starts.
# Surrounding whitespace (e.g. a trailing newline pasted into the secret) is
# stripped: requests rejects header values containing it, and an empty value
# would silently produce a useless "Bearer " Authorization header.
token = (os.getenv("HUGGING_FACE_HUB_TOKEN") or "").strip()
if not token:
    raise ValueError("Hugging Face token not found in environment variables. Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces.")
# Application object; the route handlers in this module are registered on it.
app = FastAPI()

# INFO-level root logging configuration plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Inference API endpoint for the DeepSeek model.
# NOTE(review): this model ID was originally flagged as a placeholder
# ("Replace with the correct model ID") — confirm it is the intended model.
_MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
API_URL = "https://api-inference.huggingface.co/models/" + _MODEL_ID

# Bearer-token authorization header sent with every Inference API call.
headers = {"Authorization": "Bearer {}".format(token)}
# Define request body schema
class TextGenerationRequest(BaseModel):
    """Request body for ``POST /generate-text``.

    Only ``prompt`` is required. The sampling fields are copied as-is into the
    Inference API ``parameters`` object by the ``/generate-text`` handler. The
    bounds reject values that are meaningless for sampling (non-positive token
    counts, temperatures or penalties, and a ``top_p`` outside (0, 1]). Callers
    therefore get a 422 validation error from FastAPI instead of an opaque
    upstream failure.
    """

    # Text to continue; must be non-empty.
    prompt: str = Field(..., min_length=1)
    # Increased to allow longer responses.
    max_new_tokens: int = Field(150, gt=0)
    # Slightly increased for more varied responses; must be > 0 (logits are divided by it).
    temperature: float = Field(0.8, gt=0)
    # Broader vocabulary sampling; number of candidate tokens kept.
    top_k: int = Field(100, gt=0)
    # Increased diversity; a cumulative probability mass, hence (0, 1].
    top_p: float = Field(0.92, gt=0, le=1)
    # Penalizes repetition; must be positive.
    repetition_penalty: float = Field(1.2, gt=0)
    # Sample instead of decoding greedily.
    do_sample: bool = True
# Function to remove repetitive content in the output.
# A sentence boundary is the whitespace that follows '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def remove_repetition(text: str) -> str:
    """Remove repetitive content from generated text.

    Two kinds of repetition are collapsed, comparing case-insensitively:

    * a word that repeats the word immediately before it
      (``"the the cat"`` -> ``"the cat"``);
    * a sentence identical to an earlier sentence
      (``"Hi there. Hi there. Bye!"`` -> ``"Hi there. Bye!"``).

    Ordinary reuse of common words is preserved, unlike a global "keep every
    word once" filter, so ``"the cat and the dog"`` comes back unchanged. Runs
    of whitespace (including newlines) are normalized to single spaces.

    Args:
        text: Raw text returned by the model.

    Returns:
        The de-duplicated text, or ``""`` for empty or whitespace-only input.
    """
    kept = []
    seen = set()
    for sentence in _SENTENCE_BOUNDARY.split(text.strip()):
        words = []
        for word in sentence.split():
            # Drop stutters such as "the the"; non-adjacent repeats are kept.
            if words and word.lower() == words[-1].lower():
                continue
            words.append(word)
        if not words:
            continue
        deduped = " ".join(words)
        key = deduped.lower()
        if key in seen:
            continue  # the model looped back to a sentence it already produced
        seen.add(key)
        kept.append(deduped)
    return " ".join(kept)
# (connect, read) timeouts in seconds for the Inference API call. Without a
# timeout, a stalled upstream would keep the request and its worker thread
# open indefinitely.
_UPSTREAM_TIMEOUT = (10, 120)


def _extract_generated_text(data) -> str:
    """Return the ``generated_text`` string from a decoded Inference API reply.

    The reply is expected to be a non-empty list whose first item is a dict
    containing a string ``generated_text`` entry.

    Raises:
        ValueError: if *data* does not have that shape.
    """
    if not isinstance(data, list) or not data or not isinstance(data[0], dict):
        raise ValueError("expected a non-empty JSON list of objects")
    text = data[0].get("generated_text")
    if not isinstance(text, str):
        raise ValueError("missing or non-string 'generated_text' field")
    return text


# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text for ``request.prompt`` via the Hugging Face Inference API.

    The request's sampling parameters are forwarded unchanged. The returned
    text is passed through ``remove_repetition`` before being sent back.

    Returns:
        ``{"generated_text": <cleaned text>}``.

    Raises:
        HTTPException:
            * the upstream status code and body when the Inference API answers
              with a non-200 status;
            * 504 when the upstream call times out;
            * 502 when the upstream cannot be reached or replies with a payload
              lacking a string ``generated_text``.
            Any other unexpected error propagates, and FastAPI reports it as a 500.
    """
    logger.info("Generating text...")
    # Prepare the payload for the Hugging Face Inference API
    payload = {
        "inputs": request.prompt,
        "parameters": {
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "repetition_penalty": request.repetition_penalty,
            "do_sample": request.do_sample,
        },
    }

    # ``requests`` is blocking. Run it in the default executor so the event
    # loop keeps serving other requests while this one waits on the model.
    loop = asyncio.get_running_loop()
    send = functools.partial(
        requests.post,
        API_URL,
        headers=headers,
        json=payload,
        timeout=_UPSTREAM_TIMEOUT,
    )
    try:
        response = await loop.run_in_executor(None, send)
    except requests.Timeout as exc:
        logger.error("Inference API request timed out: %s", exc)
        raise HTTPException(status_code=504, detail="Inference API request timed out") from exc
    except requests.RequestException as exc:
        logger.error("Inference API request failed: %s", exc)
        raise HTTPException(status_code=502, detail=f"Inference API request failed: {exc}") from exc

    # Forward upstream errors with their original status. This is raised outside
    # any broad ``except``, so it is not re-wrapped as a generic 500.
    if response.status_code != 200:
        logger.error("API Error: %s - %s", response.status_code, response.text)
        raise HTTPException(status_code=response.status_code, detail=response.text)

    try:
        # response.json() raises a ValueError subclass on a non-JSON body.
        generated_text = _extract_generated_text(response.json())
    except ValueError as exc:
        logger.error("Unexpected Inference API response: %s", response.text)
        raise HTTPException(
            status_code=502, detail=f"Unexpected response from Inference API: {exc}"
        ) from exc

    cleaned_text = remove_repetition(generated_text)  # Clean up repetition
    return {"generated_text": cleaned_text}
# Root endpoint, usable as a basic health check.
@app.get("/")
async def root():
    """Return a welcome message that points clients at ``/generate-text``."""
    greeting = (
        "Welcome to the Text Generation API! "
        "Use /generate-text to generate text."
    )
    return dict(message=greeting)
# Simple test endpoint.
@app.get("/test")
async def test():
    """Return a fixed message confirming the API process is up."""
    status_message = "API is running!"
    return {"message": status_message}