# app.py
import os
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient, HfApi

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM Chat API",
    description="API for getting chat responses from Llama model",
    version="1.0.0"
)
class ChatRequest(BaseModel):
    text: str

class ChatResponse(BaseModel):
    response: str
    status: str
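# Example payloads for the models above (a sketch; the prompt and reply text
# are illustrative, not real model output):
#   request:  {"text": "What is quantum computing?"}
#   response: {"response": "Quantum computing uses qubits ...", "status": "success"}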
# Initialize HF client at startup
client = None
try:
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not found in environment variables")
    api = HfApi(token=HF_TOKEN)
    client = InferenceClient(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # You might need to change this
        token=HF_TOKEN
    )
    logger.info("Successfully initialized HuggingFace client")
except Exception as e:
    # Leave client as None so request handlers can report a clear error
    # instead of raising a NameError later.
    logger.error(f"Error initializing HuggingFace client: {str(e)}")
def llm_chat_response(text: str) -> str:
    if client is None:
        raise HTTPException(status_code=503, detail="HuggingFace client not initialized")
    try:
        logger.info(f"Processing text: {text}")
        # Direct text generation
        response = client.text_generation(
            text + " describe in one line only",
            max_new_tokens=100,
            temperature=0.7,
            repetition_penalty=1.2
        )
        logger.info(f"Generated response: {response}")
        return response
    except Exception as e:
        logger.error(f"Error in llm_chat_response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
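# Note: text_generation sends a raw prompt string. For an instruct-tuned model
# like the one configured above, the chat API is usually a better fit. A
# minimal sketch, assuming a recent huggingface_hub release that provides
# InferenceClient.chat_completion:
#
#   completion = client.chat_completion(
#       messages=[{"role": "user", "content": text + " describe in one line only"}],
#       max_tokens=100,
#       temperature=0.7,
#   )
#   response = completion.choices[0].message.content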
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.info(f"Received chat request with text: {request.text}")
        response = llm_chat_response(request.text)
        return ChatResponse(response=response, status="success")
    except HTTPException as he:
        logger.error(f"HTTP Exception in chat endpoint: {str(he)}")
        raise he
    except Exception as e:
        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
    return {"message": "Welcome to the LLM Chat API. Use POST /chat endpoint to get responses."}
@app.exception_handler(404)
async def not_found_handler(request, exc):
    return JSONResponse(
        status_code=404,
        content={"error": "Endpoint not found. Please use POST /chat for queries."}
    )
@app.exception_handler(405)
async def method_not_allowed_handler(request, exc):
    return JSONResponse(
        status_code=405,
        content={"error": "Method not allowed. Please check the API documentation."}
    )
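# To run locally (a sketch; a Space typically starts the server from its
# Dockerfile or entry point, and 7860 is the usual Spaces port):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# then query the endpoint:
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"text": "What is quantum computing?"}'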