from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
import os
import logging
# Read the Hugging Face token from the environment variable
# (injected as a secret in Hugging Face Spaces — never hard-coded).
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token is None:
    # Fail fast at import time: without a token every upstream call would be rejected.
    raise ValueError("Hugging Face token not found in environment variables. Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces.")
# Initialize FastAPI app
app = FastAPI()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hugging Face Inference API endpoint for DeepSeek
API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # Replace with the correct model ID
headers = {"Authorization": f"Bearer {token}"}
# Define request body schema
class TextGenerationRequest(BaseModel):
    """Request body for POST /generate-text.

    Mirrors the sampling parameters accepted by the Hugging Face Inference
    API; every field except `prompt` has a tuned default.
    """
    prompt: str  # required: text the model should continue
    max_new_tokens: int = 150  # Increased to allow longer responses
    temperature: float = 0.8  # Slightly increased for more varied responses
    top_k: int = 100  # Broader vocabulary sampling
    top_p: float = 0.92  # Increased diversity
    repetition_penalty: float = 1.2  # Penalizes repetition
    do_sample: bool = True  # sample instead of greedy decoding
# Function to remove repetitive content in the output
def remove_repetition(text: str) -> str:
    """Return *text* with every case-insensitive duplicate word dropped.

    Only the first occurrence of each word (compared via ``str.lower``) is
    kept; the original casing of that first occurrence is preserved.
    """
    encountered: set[str] = set()
    kept: list[str] = []
    for piece in text.split():
        folded = piece.lower()
        if folded in encountered:
            continue  # duplicate (case-insensitive) — skip it
        encountered.add(folded)
        kept.append(piece)
    return " ".join(kept)
# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Proxy a text-generation request to the Hugging Face Inference API.

    Forwards the prompt and sampling parameters upstream, strips repeated
    words from the generated text, and returns ``{"generated_text": ...}``.

    Raises:
        HTTPException: with the upstream status code and body when the
            Inference API returns a non-200 response, or 500 on any other
            failure (network error, unexpected response shape, ...).
    """
    try:
        logger.info("Generating text...")
        # Prepare the payload for the Hugging Face Inference API
        payload = {
            "inputs": request.prompt,
            "parameters": {
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "repetition_penalty": request.repetition_penalty,
                "do_sample": request.do_sample,
            },
        }
        # Send request to the Hugging Face Inference API.
        # A timeout is essential: without one a stalled upstream would hang
        # this worker indefinitely.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
        # Check for errors in the response
        if response.status_code != 200:
            # Lazy %-formatting keeps log args unevaluated unless emitted.
            logger.error("API Error: %s - %s", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)
        # Extract and process the generated text
        generated_text = response.json()[0]["generated_text"]
        cleaned_text = remove_repetition(generated_text)  # Clean up repetition
        return {"generated_text": cleaned_text}
    except HTTPException:
        # Re-raise untouched: the broad handler below must not rewrap an
        # upstream error (e.g. 429/503) as a generic 500.
        raise
    except Exception as e:
        # logger.exception records the full traceback for debugging.
        logger.exception("Error generating text")
        raise HTTPException(status_code=500, detail=str(e))
# Add a root endpoint for health checks
@app.get("/")
async def root():
    """Health-check landing route that points callers at /generate-text."""
    welcome = "Welcome to the Text Generation API! Use /generate-text to generate text."
    return {"message": welcome}
# Add a test endpoint
@app.get("/test")
async def test():
    """Trivial liveness probe confirming the service is up."""
    status_message = {"message": "API is running!"}
    return status_message