from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
from openai import OpenAI  # requires the openai>=1.0 SDK

# Read the NVIDIA API key from environment variables
api_key = os.getenv("NVIDIA_API_KEY")
if api_key is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set NVIDIA_API_KEY.")

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure an OpenAI client pointed at NVIDIA's OpenAI-compatible API
client = OpenAI(
    api_key=api_key,  # NVIDIA API key
    base_url="https://integrate.api.nvidia.com/v1",  # NVIDIA base URL
)

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 1024
    temperature: float = 0.4
    top_p: float = 0.7
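
# Illustrative request body for the endpoints below (example values, not
# enforced beyond the schema defaults above):
#   {"prompt": "Hello", "max_new_tokens": 256, "temperature": 0.4, "top_p": 0.7}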

# Define response schema for non-streaming
class TextGenerationResponse(BaseModel):
    generated_text: str

# Define API endpoint for non-streaming text generation
@app.post("/generate-text", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text with NVIDIA API...")
        # Send the chat completion request to the NVIDIA API
        response = client.chat.completions.create(
            model="meta/llama-3.1-405b-instruct",  # Model served by the NVIDIA API
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=False,  # Non-streaming response
        )
        # Extract the generated text from the first choice
        response_text = response.choices[0].message.content
        logger.info("Text generation completed successfully.")
        return {"generated_text": response_text}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))
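
# Example request for local testing (a sketch; assumes the app is served with
# uvicorn on localhost:8000 -- adjust host/port for your deployment):
#
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs", "max_new_tokens": 128}'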

# Define API endpoint for streaming text generation
@app.post("/generate-text-stream")
async def generate_text_stream(request: TextGenerationRequest):
    async def generate():
        try:
            logger.info("Streaming text with NVIDIA API...")
            # Send the chat completion request with streaming enabled.
            # Note: the sync client is iterated here; for high-concurrency
            # deployments the SDK's AsyncOpenAI client would avoid blocking.
            response = client.chat.completions.create(
                model="meta/llama-3.1-405b-instruct",  # Model served by the NVIDIA API
                messages=[{"role": "user", "content": request.prompt}],
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_new_tokens,
                stream=True,  # Streaming response
            )
            # Stream the response chunks to the client as they arrive
            for chunk in response:
                # Each chunk carries an incremental delta; content may be None
                content = chunk.choices[0].delta.content if chunk.choices else None
                if content:
                    yield content
            logger.info("Text streaming completed successfully.")
        except Exception as e:
            logger.error(f"Error streaming text: {e}")
            yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")
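
# Example: consuming the stream locally (a sketch; same uvicorn assumption as
# above). curl's -N flag disables buffering so tokens print as they arrive:
#
#   curl -N -X POST http://localhost:8000/generate-text-stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Tell me a short story"}'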

# Add a root endpoint for health checks
@app.get("/")
async def root():
    return {"message": "Welcome to the NVIDIA Text Generation API (OpenAI-compatible client)!"}

# Add a test endpoint
@app.get("/test")
async def test():
    return {"message": "API is running!"}