808-GPT2 / app.py
Fred808's picture
Update app.py
9ddd59a verified
raw
history blame
2.38 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
import os
import logging
# --- Credentials -----------------------------------------------------------
# The Inference API token is taken from the process environment; start-up is
# aborted with a ValueError when the variable is absent.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if token is None:
    raise ValueError(
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )

# --- Logging ---------------------------------------------------------------
# INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application -----------------------------------------------------------
app = FastAPI()

# --- Upstream model endpoint -----------------------------------------------
# Hosted inference URL for the bigscience/bloom-7b1 model, and the bearer
# authorization header sent with every upstream request.
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom-7b1"
headers = {"Authorization": "Bearer " + token}
# Request body accepted by POST /generate-text.
# Only `prompt` is required; the remaining fields fall back to the defaults
# below and are forwarded unchanged as the "parameters" object of the
# upstream inference request.
class TextGenerationRequest(BaseModel):
    # Text sent upstream as the "inputs" value.
    prompt: str

    # Forwarded as "max_new_tokens"; smaller values give faster responses.
    max_new_tokens: int = 50

    # Forwarded as "temperature"; lower values give more deterministic output.
    temperature: float = 0.7

    # Forwarded as "top_k".
    top_k: int = 50

    # Forwarded as "top_p".
    top_p: float = 0.9

    # Forwarded as "do_sample".
    do_sample: bool = True
# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text by forwarding the prompt to the Hugging Face Inference API.

    Args:
        request: Prompt and sampling parameters to send upstream.

    Returns:
        A dict of the form ``{"generated_text": ...}`` whose value is the
        ``generated_text`` field of the first element of the upstream JSON
        response.

    Raises:
        HTTPException: Status 500 when the upstream call fails, times out,
            returns an error status, or returns a payload that does not
            contain ``generated_text``.
    """
    # Imported here so the block does not depend on the file-level import list.
    import asyncio
    import functools

    logger.info("Generating text...")

    # Prepare the payload for the Hugging Face Inference API
    payload = {
        "inputs": request.prompt,
        "parameters": {
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "do_sample": request.do_sample,
        },
    }

    try:
        # requests is a blocking client: calling it directly inside a
        # coroutine would stall the event loop (and every other client) for
        # the whole inference. Run it on a worker thread instead.
        # The (connect, read) timeout bounds a stalled upstream; without it
        # the call can wait forever.
        upstream_call = functools.partial(
            requests.post,
            API_URL,
            headers=headers,
            json=payload,
            timeout=(10, 120),
        )
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, upstream_call)
        response.raise_for_status()  # Raise an error for bad responses (4xx or 5xx)

        # Validate the shape before indexing so a surprising payload yields a
        # readable error instead of a bare KeyError/IndexError message.
        body = response.json()
        has_text = (
            isinstance(body, list)
            and len(body) > 0
            and isinstance(body[0], dict)
            and "generated_text" in body[0]
        )
        if not has_text:
            logger.error("Unexpected inference API response: %r", body)
            raise ValueError(
                "Inference API response did not contain 'generated_text'"
            )

        return {"generated_text": body[0]["generated_text"]}
    except Exception as e:
        # Top-level boundary for this endpoint: log with traceback and report
        # the failure to the client with the same 500 status as before.
        logger.exception("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root endpoint, usable as a health check: GET / replies with a fixed welcome
# message that points callers at /generate-text.
# (Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays as it was.)
@app.get("/")
async def root():
    welcome = "Welcome to the Text Generation API! Use /generate-text to generate text."
    return {"message": welcome}
# Test endpoint: GET /test replies with a fixed message confirming that the
# application is up.
# (Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays as it was.)
@app.get("/test")
async def test():
    status_message = "API is running!"
    return {"message": status_message}