import logging
import os

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel


# Hugging Face API token, read from the environment (configured as a Space secret).
token = os.getenv("HUGGING_FACE_HUB_TOKEN")

# Fail fast at import time rather than on the first request. An empty value is
# treated like a missing one: it would otherwise produce a bare "Bearer "
# Authorization header below.
if not token:
    raise ValueError(
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Inference API endpoint for the model that /generate-text queries.
API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
# Authorization header attached to every request sent to API_URL.
headers = {"Authorization": f"Bearer {token}"}


class TextGenerationRequest(BaseModel):
    """Request body for ``POST /generate-text``.

    ``prompt`` becomes the Inference API payload's ``inputs``; every other
    field is forwarded unchanged under ``parameters``. The defaults below
    apply when a client omits a field.
    """

    # Text sent to the model as the payload's "inputs".
    prompt: str

    # Generation/sampling settings, passed through as-is to the Inference API.
    # No range validation happens here, so out-of-range values are presumably
    # rejected (if at all) upstream -- see the Hugging Face text-generation
    # docs for the exact meaning and limits of each parameter.
    max_new_tokens: int = 150

    temperature: float = 0.8

    top_k: int = 100

    top_p: float = 0.92

    repetition_penalty: float = 1.2

    do_sample: bool = True


def remove_repetition(text: str, *, consecutive_only: bool = False) -> str:
    """Drop repeated words from *text*, comparing words case-insensitively.

    By default a word is dropped if the same word appeared *anywhere* earlier
    in the text, so common words such as "the" survive only once. This
    de-duplicates aggressively and can mangle ordinary prose. Pass
    ``consecutive_only=True`` to drop only immediate repeats instead
    (``"the the cat"`` -> ``"the cat"``).

    In both modes the text is split on any whitespace and re-joined with
    single spaces, so newlines and indentation are not preserved. The first
    spelling of each kept word is the one retained.

    Args:
        text: Text to clean.
        consecutive_only: If true, only remove a word that repeats the word
            immediately before it.

    Returns:
        The cleaned text; ``""`` for empty or whitespace-only input.
    """
    kept = []
    if consecutive_only:
        previous_key = None
        for word in text.split():
            key = word.lower()
            if key != previous_key:
                kept.append(word)
            previous_key = key
        return " ".join(kept)

    seen = set()
    for word in text.split():
        key = word.lower()
        if key not in seen:
            seen.add(key)
            kept.append(word)
    return " ".join(kept)


# Seconds to wait for the Inference API to answer. Without a timeout a stalled
# upstream connection would keep the request (and its worker thread) open
# indefinitely. NOTE(review): 120 s is a guess that leaves room for a large
# model to respond; tune as needed.
_INFERENCE_TIMEOUT_SECONDS = 120


def _extract_generated_text(data) -> str:
    """Return ``data[0]["generated_text"]`` from a decoded Inference API body.

    Raises:
        HTTPException: 502 when the body does not have that shape or the
            field is not a string.
    """
    try:
        text = data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        text = None
    if not isinstance(text, str):
        raise HTTPException(
            status_code=502,
            detail=f"Unexpected response from inference API: {str(data)[:500]}",
        )
    return text


@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text for ``request.prompt`` via the Hugging Face Inference API.

    The sampling fields of *request* are forwarded unchanged as the payload's
    ``parameters``. The first result's ``generated_text`` is post-processed
    with ``remove_repetition`` before being returned.

    Returns:
        ``{"generated_text": <cleaned text>}``.

    Raises:
        HTTPException: the upstream status code and body when the API answers
            with a non-200 status; 502 when the API cannot be reached or
            returns an unexpected body; 500 for any other failure.
    """
    try:
        logger.info("Generating text...")

        payload = {
            "inputs": request.prompt,
            "parameters": {
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "repetition_penalty": request.repetition_penalty,
                "do_sample": request.do_sample,
            },
        }

        try:
            # requests is blocking; run it in a worker thread so the event
            # loop keeps serving other requests while the model responds.
            response = await run_in_threadpool(
                requests.post,
                API_URL,
                headers=headers,
                json=payload,
                timeout=_INFERENCE_TIMEOUT_SECONDS,
            )
        except requests.RequestException as exc:
            logger.error(f"Inference API request failed: {exc}")
            raise HTTPException(
                status_code=502, detail=f"Inference API request failed: {exc}"
            ) from exc

        if response.status_code != 200:
            logger.error(f"API Error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=response.status_code, detail=response.text)

        try:
            data = response.json()
        except ValueError as exc:
            logger.error(f"Inference API returned non-JSON body: {response.text[:500]}")
            raise HTTPException(
                status_code=502, detail="Inference API returned a non-JSON body"
            ) from exc

        cleaned_text = remove_repetition(_extract_generated_text(data))
        return {"generated_text": cleaned_text}

    except HTTPException:
        # Already carries the intended status code; letting the generic
        # handler below re-wrap it would turn e.g. an upstream 503 into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def root():
    """Landing endpoint: return a fixed message pointing callers at /generate-text."""
    welcome = "Welcome to the Text Generation API! Use /generate-text to generate text."
    return dict(message=welcome)


@app.get("/test")
async def test():
    """Lightweight health-check endpoint; always returns the same fixed message."""
    status = {"message": "API is running!"}
    return status