from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
import openai
# Read the NVIDIA API key from environment variables
api_key = os.getenv("NVIDIA_API_KEY")
if api_key is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")
# Initialize FastAPI app
app = FastAPI()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configure OpenAI client to use NVIDIA's API (via OpenAI wrapper)
openai.api_key = api_key # Using the NVIDIA API key
openai.api_base = "https://integrate.api.nvidia.com/v1" # Set the NVIDIA base URL
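# NOTE: the module-level openai.api_key / openai.api_base configuration and the
# openai.ChatCompletion.create() calls below follow the legacy pre-1.0 SDK
# interface (e.g. pip install "openai<1.0"); openai>=1.0 removed this in favor
# of an OpenAI(base_url=..., api_key=...) client object.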
# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 1024
    temperature: float = 0.4
    top_p: float = 0.7
# Define response schema for non-streaming
class TextGenerationResponse(BaseModel):
    generated_text: str
# Define API endpoint for non-streaming text generation
@app.post("/generate-text", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text with NVIDIA API...")
        # Send a chat completion request to the NVIDIA API
        response = openai.ChatCompletion.create(
            model="meta/llama-3.1-405b-instruct",  # Model served by the NVIDIA API
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=False  # Non-streaming response
        )
        # Extract the generated text
        response_text = response["choices"][0]["message"]["content"]
        logger.info("Text generation completed successfully.")
        return {"generated_text": response_text}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))
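# Example request (assumption: the app is served locally on port 8000):
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs."}'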
# Define API endpoint for streaming text generation
@app.post("/generate-text-stream")
async def generate_text_stream(request: TextGenerationRequest):
    async def generate():
        try:
            logger.info("Streaming text with NVIDIA API...")
            # Send a streaming chat completion request to the NVIDIA API
            response = openai.ChatCompletion.create(
                model="meta/llama-3.1-405b-instruct",  # Model served by the NVIDIA API
                messages=[{"role": "user", "content": request.prompt}],
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_new_tokens,
                stream=True  # Streaming response
            )
            # Relay the response chunks to the client as they arrive
            for chunk in response:
                if isinstance(chunk, dict):  # Chunks are dict-like objects in the legacy SDK
                    # Extract the incremental content from each chunk safely
                    content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                    if content:
                        yield content  # Stream content to the client
                else:
                    logger.error(f"Unexpected chunk format: {chunk}")  # Log if the chunk format is unexpected
            logger.info("Text streaming completed successfully.")
        except Exception as e:
            logger.error(f"Error streaming text: {e}")
            yield f"Error: {str(e)}"
    return StreamingResponse(generate(), media_type="text/plain")
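# Example request (assumption: the app is served locally on port 8000);
# curl's -N flag disables output buffering so chunks print as they arrive:
#   curl -N -X POST http://localhost:8000/generate-text-stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs."}'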
# Add a root endpoint for health checks
@app.get("/")
async def root():
return {"message": "Welcome to the NVIDIA Text Generation API using OpenAI Wrapper!"}
# Add a test endpoint
@app.get("/test")
async def test():
return {"message": "API is running!"} |