from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
from openai import AsyncOpenAI

# Read the NVIDIA API key from environment variables
api_key = os.getenv("NVIDIA_API_KEY")
if api_key is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set NVIDIA_API_KEY.")

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure the OpenAI client (openai>=1.0 interface) to use NVIDIA's
# OpenAI-compatible endpoint. AsyncOpenAI is used so API calls don't
# block FastAPI's event loop.
client = AsyncOpenAI(
    api_key=api_key,
    base_url="https://integrate.api.nvidia.com/v1",
)

# Model served through the NVIDIA API
MODEL_NAME = "meta/llama-3.1-405b-instruct"

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 1024
    temperature: float = 0.4
    top_p: float = 0.7

# Define response schema for non-streaming generation
class TextGenerationResponse(BaseModel):
    generated_text: str

# Define API endpoint for non-streaming text generation
@app.post("/generate-text", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text with NVIDIA API...")
        # Send the request and wait for the complete response
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=False,  # Non-streaming response
        )
        # Extract the generated text from the first choice
        response_text = response.choices[0].message.content
        logger.info("Text generation completed successfully.")
        return {"generated_text": response_text}
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Define API endpoint for streaming text generation
@app.post("/generate-text-stream")
async def generate_text_stream(request: TextGenerationRequest):
    async def generate():
        try:
            logger.info("Streaming text with NVIDIA API...")
            stream = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": request.prompt}],
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_new_tokens,
                stream=True,  # Streaming response
            )
            # Forward each chunk's text to the client as it arrives
            async for chunk in stream:
                if chunk.choices:
                    content = chunk.choices[0].delta.content
                    if content:
                        yield content
            logger.info("Text streaming completed successfully.")
        except Exception as e:
            logger.error(f"Error streaming text: {e}")
            yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")

# Root endpoint for health checks
@app.get("/")
async def root():
    return {"message": "Welcome to the NVIDIA Text Generation API using OpenAI Wrapper!"}

# Simple test endpoint
@app.get("/test")
async def test():
    return {"message": "API is running!"}
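
# Example requests against the two endpoints, for reference. This is a sketch
# assuming the server is running locally on port 8000; the host, port, and
# prompt text below are illustrative, not from the source.
#
#   curl -X POST http://localhost:8000/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain CUDA in one sentence."}'
#
#   curl -N -X POST http://localhost:8000/generate-text-stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain CUDA in one sentence."}'
#
# The -N flag disables curl's output buffering so streamed chunks are printed
# as they arrive rather than all at once.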
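
# Optional entry point for local development; a minimal sketch assuming
# uvicorn is installed (e.g. pip install fastapi uvicorn openai). You can
# equivalently launch with `uvicorn <module>:app --reload`, replacing
# <module> with this file's name.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)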