File size: 3,157 Bytes
64c0b0e
 
9ddd59a
624df6b
e0e5738
64c0b0e
624df6b
 
 
 
 
64c0b0e
 
 
e0e5738
 
 
 
dd25f43
28bfc95
9ddd59a
64c0b0e
 
 
 
0fb92e1
 
 
 
 
64c0b0e
 
0fb92e1
 
 
 
 
 
 
 
 
 
64c0b0e
 
 
 
e0e5738
4c64189
9ddd59a
 
 
 
 
 
 
 
0fb92e1
9ddd59a
 
 
 
 
 
dd25f43
 
 
 
 
9ddd59a
0fb92e1
9ddd59a
0fb92e1
 
64c0b0e
e0e5738
64c0b0e
 
 
acbb541
64c0b0e
acbb541
 
 
 
 
0fb92e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import logging
import os

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# Credentials: the Inference API calls below authenticate with a Hugging Face
# access token. Refuse to start (fail at import time) when it is absent.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if token is None:
    missing_token_msg = (
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )
    raise ValueError(missing_token_msg)

# Web application object; the endpoints below register themselves on it.
app = FastAPI()

# Root logging at INFO, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upstream model endpoint and the auth header sent with every request.
# TODO: confirm this is the intended model ID before deploying.
API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
headers = {"Authorization": "Bearer " + token}

# JSON body accepted by POST /generate-text.
class TextGenerationRequest(BaseModel):
    """Prompt plus generation options forwarded to the Inference API.

    Only ``prompt`` is required; every other field falls back to the
    default declared here when the client omits it.
    """

    # Text the model is asked to continue.
    prompt: str
    # Upper bound on how many new tokens the model may produce.
    max_new_tokens: int = 150
    # Sampling controls, passed through unchanged in the "parameters" object.
    temperature: float = 0.8
    top_k: int = 100
    top_p: float = 0.92
    # Penalty applied to repeated tokens (values above 1.0 penalize repetition).
    repetition_penalty: float = 1.2
    # Whether the model should sample instead of decoding deterministically.
    do_sample: bool = True

# Post-process generated text to strip degenerate word stutters.
def remove_repetition(text: str) -> str:
    """Collapse runs of immediately repeated words in *text*.

    Each run of consecutive words that are equal ignoring case (for example
    ``"very very Very"``) is reduced to its first occurrence. Words that recur
    non-adjacently are kept, so normal prose such as ``"the cat and the dog"``
    is left intact. The previous version dropped *every* later occurrence of
    a word anywhere in the text, which mangled ordinary sentences.

    Whitespace is normalized: leading/trailing whitespace is removed and
    every internal whitespace run becomes a single space.

    Args:
        text: The raw generated text.

    Returns:
        The text with consecutive duplicate words removed.
    """
    kept: list[str] = []
    previous_key = None
    for word in text.split():
        key = word.lower()
        # Only compare against the immediately preceding word, so legitimate
        # reuse of common words later in the text is preserved.
        if key != previous_key:
            kept.append(word)
        previous_key = key
    return " ".join(kept)

# Seconds to wait for the Inference API before giving up. Large models can be
# slow to answer, so this is deliberately generous; without any timeout a
# stalled upstream connection would hang the request forever.
REQUEST_TIMEOUT_SECONDS = 120


def _build_payload(request: TextGenerationRequest) -> dict:
    """Translate a client request into the Inference API JSON payload."""
    return {
        "inputs": request.prompt,
        "parameters": {
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "repetition_penalty": request.repetition_penalty,
            "do_sample": request.do_sample,
        },
    }


def _extract_generated_text(data: object) -> str:
    """Return ``data[0]["generated_text"]`` from a decoded API response.

    Raises:
        HTTPException: 502 if *data* does not have the expected
            ``[{"generated_text": str, ...}, ...]`` shape.
    """
    try:
        text = data[0]["generated_text"]  # type: ignore[index]
    except (IndexError, KeyError, TypeError):
        text = None
    if not isinstance(text, str):
        logger.error("Unexpected API response: %r", data)
        raise HTTPException(
            status_code=502,
            detail="Unexpected response format from the inference API.",
        )
    return text


# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text for ``request.prompt`` via the Hugging Face Inference API.

    Args:
        request: Prompt and generation parameters supplied by the client.

    Returns:
        ``{"generated_text": str}``, where the text has been post-processed
        by ``remove_repetition``.

    Raises:
        HTTPException: with the upstream status code when the API answers
            with a non-200 response; 502 when the API body is not the
            expected shape; 500 for any other failure (including network
            errors and timeouts).
    """
    try:
        logger.info("Generating text...")
        payload = _build_payload(request)

        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests while this one waits on the API.
        response = await run_in_threadpool(
            requests.post,
            API_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )

        if response.status_code != 200:
            logger.error("API Error: %s - %s", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)

        try:
            data = response.json()
        except ValueError:
            # Non-JSON body: let the shape check below report it as a 502.
            data = None
        generated_text = _extract_generated_text(data)
        return {"generated_text": remove_repetition(generated_text)}
    except HTTPException:
        # Already carries the intended status code; previously the generic
        # handler below swallowed it and turned every upstream error into 500.
        raise
    except Exception as e:
        logger.exception("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

# Root endpoint: a cheap liveness check that also points at the main route.
@app.get("/")
async def root():
    """Return a welcome message naming the text-generation route."""
    welcome = "Welcome to the Text Generation API! Use /generate-text to generate text."
    return {"message": welcome}

# Minimal endpoint confirming the service is up and routing requests.
@app.get("/test")
async def test():
    """Report that the API process is running."""
    status_message = "API is running!"
    return {"message": status_message}