|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import StreamingResponse |
|
from pydantic import BaseModel |
|
import os |
|
import logging |
|
import openai |
|
from typing import Optional |
|
|
|
|
|
# --- Application bootstrap -------------------------------------------------
# Fail fast at import time if the NVIDIA credential is absent, so the server
# never starts in a state where every request would 500.
if (api_key := os.getenv("NVIDIA_API_KEY")) is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the (legacy, pre-1.0) openai client library at NVIDIA's
# OpenAI-compatible inference endpoint.
openai.api_key = api_key
openai.api_base = "https://integrate.api.nvidia.com/v1"
|
|
|
|
|
class TextGenerationRequest(BaseModel):
    """Request body for both generation endpoints."""

    # The user message forwarded verbatim as the single chat turn.
    prompt: str

    # Upper bound on generated tokens (mapped to the API's `max_tokens`).
    max_new_tokens: int = 1024

    # Sampling temperature; lower values make output more deterministic.
    # NOTE(review): no range validation here — out-of-range values are
    # passed through and rejected (if at all) by the upstream API.
    temperature: float = 0.4

    # Nucleus-sampling probability mass.
    top_p: float = 0.7
|
|
|
|
|
class TextGenerationResponse(BaseModel):
    """Response body for the non-streaming /generate-text endpoint."""

    # Full completion text extracted from the first choice of the API reply.
    generated_text: str
|
|
|
|
|
@app.post("/generate-text", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    """Generate a complete (non-streamed) chat completion.

    Sends the prompt as a single user message to the NVIDIA-hosted model
    and returns the full completion text.

    Raises:
        HTTPException: 500 with the underlying error message if the
            upstream call fails for any reason.
    """
    try:
        logger.info("Generating text with NVIDIA API...")

        # Use the async `acreate` variant (available in openai<1.0).
        # The original synchronous `create` call would block the event
        # loop for the entire duration of the upstream request.
        response = await openai.ChatCompletion.acreate(
            model="meta/llama-3.1-405b-instruct",
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=False,
        )

        response_text = response["choices"][0]["message"]["content"]
        logger.info("Text generation completed successfully.")
        return {"generated_text": response_text}

    except Exception as e:
        # Boundary handler: log with traceback, surface the message to the
        # client, and keep the original exception as the chained cause.
        logger.exception("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.post("/generate-text-stream")
async def generate_text_stream(request: TextGenerationRequest):
    """Stream a chat completion back to the client as plain text chunks.

    Opens a streaming completion against the NVIDIA-hosted model and
    forwards each content delta as it arrives. Errors are reported
    in-band as a final "Error: ..." chunk, since HTTP status headers are
    already sent once streaming begins.
    """

    async def generate():
        try:
            logger.info("Streaming text with NVIDIA API...")

            # Async variant (`acreate`, openai<1.0) returns an async
            # iterator of chunks. The original synchronous `create` +
            # plain `for` loop blocked the event loop for the whole
            # stream, starving every other request.
            response = await openai.ChatCompletion.acreate(
                model="meta/llama-3.1-405b-instruct",
                messages=[{"role": "user", "content": request.prompt}],
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_new_tokens,
                stream=True,
            )

            async for chunk in response:
                # Legacy OpenAIObject chunks are dict subclasses, so this
                # check normally passes; anything else is logged and skipped.
                if isinstance(chunk, dict):
                    # `or [{}]` also guards a present-but-empty "choices"
                    # list, which the plain .get() default would not
                    # (it would raise IndexError).
                    content = (chunk.get("choices") or [{}])[0].get("delta", {}).get("content", "")
                    if content:
                        yield content
                else:
                    logger.error(f"Unexpected chunk format: {chunk}")

            logger.info("Text streaming completed successfully.")
        except Exception as e:
            # Headers are already sent; report the failure in-band.
            logger.exception("Error streaming text: %s", e)
            yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: returns a static welcome message."""
    greeting = "Welcome to the NVIDIA Text Generation API using OpenAI Wrapper!"
    return {"message": greeting}
|
|
|
|
|
@app.get("/test")
async def test():
    """Liveness check: confirms the API process is serving requests."""
    status_message = "API is running!"
    return {"message": status_message}