from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub import login
import os
import torch
import logging
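
# Authenticate with the Hugging Face Hub using the token read from the
# environment (the HUGGING_FACE_HUB_TOKEN secret on Hugging Face Spaces).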
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token is None:
    raise ValueError(
        "Hugging Face token not found in environment variables. "
        "Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces."
    )

login(token=token)
app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
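
# Load the tokenizer and model once at startup so every request reuses them.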
model_id = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_id)

logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
)
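
# Wrap the model and tokenizer in a reusable text-generation pipeline.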
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
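
# Request schema: every sampling parameter has a default, so clients only
# need to send a prompt.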
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True
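
# Example request body (unset fields fall back to the defaults above):
# {"prompt": "Once upon a time", "max_new_tokens": 40}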
@app.post("/generate-text")
def generate_text(request: TextGenerationRequest):
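    # Declared as a plain "def" so FastAPI runs this blocking pipeline call
    # in its threadpool instead of stalling the event loop.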
    try:
        logger.info("Generating text...")
        outputs = pipe(
            request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            do_sample=request.do_sample,
            return_full_text=False,
        )
        return {"generated_text": outputs[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))
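
# Example call, assuming the app is served locally on port 8000:
#   curl -X POST http://localhost:8000/generate-text \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time"}'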
@app.get("/test")
async def root():
    return {"message": "API is running!"}