"""FastAPI wrapper around NVIDIA's hosted Llama 3.2 90B Vision Instruct API.

Exposes a text chat endpoint, an image chat endpoint, and a root welcome
endpoint. Requests are forwarded to the NVIDIA chat-completions endpoint and
the first choice's message content is returned to the caller.
"""

import base64
import os
from typing import Optional

import requests
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

app = FastAPI()

# NVIDIA API endpoint and model identifier
NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
MODEL_NAME = "meta/llama-3.2-90b-vision-instruct"

# Prefer the NVIDIA_API_KEY environment variable so the secret stays out of
# source control; the placeholder remains as a backward-compatible fallback.
API_KEY = os.environ.get("NVIDIA_API_KEY", "your_nvidia_api_key_here")  # Replace with your actual API key

# (connect, read) timeouts in seconds. Without a timeout `requests` waits
# forever on a stalled upstream; the read timeout is generous because large
# model completions can be slow.
REQUEST_TIMEOUT = (10, 120)

# MIME types embedded verbatim in the data URI. Anything else falls back to
# PNG so a client-supplied Content-Type can never inject markup into the tag.
# NOTE(review): assumes the upstream accepts PNG and JPEG — confirm against
# the NVIDIA model documentation.
_SUPPORTED_IMAGE_TYPES = frozenset({"image/png", "image/jpeg"})
_DEFAULT_IMAGE_TYPE = "image/png"


# Request model for text-based input
class TextRequest(BaseModel):
    """Body of a ``POST /chat/text`` request."""

    message: str  # user prompt forwarded as the single chat message
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0


# Function to call the NVIDIA API
def call_nvidia_api(payload: dict):
    """POST *payload* to the NVIDIA endpoint and return the decoded JSON body.

    This function is synchronous (it uses ``requests``); async callers should
    run it in a worker thread.

    Raises:
        HTTPException: carrying the upstream status code when the response
            status is not 200.
        requests.RequestException: on connection errors or timeouts.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Accept": "application/json",
    }
    response = requests.post(
        NVIDIA_API_URL,
        headers=headers,
        json=payload,
        timeout=REQUEST_TIMEOUT,
    )
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail="NVIDIA API request failed",
        )
    return response.json()


async def _complete(payload: dict) -> dict:
    """Send *payload* upstream and wrap the first choice's content in a dict.

    The blocking HTTP call runs in the threadpool so the event loop stays
    responsive while waiting on the upstream service.

    Raises:
        HTTPException: the upstream status code if the NVIDIA API rejected the
            request, or 500 for any other failure (network error, unexpected
            response shape).
    """
    try:
        response = await run_in_threadpool(call_nvidia_api, payload)
        return {"response": response["choices"][0]["message"]["content"]}
    except HTTPException:
        # HTTPException is an Exception subclass; re-raise it untouched so the
        # upstream status code is not masked as a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# Endpoint for text-based input
@app.post("/chat/text")
async def chat_with_text(request: TextRequest):
    """Forward a text prompt to the model and return its reply.

    Returns:
        ``{"response": <model reply>}``.
    """
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": request.message}],
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "stream": False,
    }
    return await _complete(payload)


# Endpoint for image-based input
@app.post("/chat/image")
async def chat_with_image(file: UploadFile = File(...)):
    """Ask the model to describe an uploaded image.

    The image is base64-encoded and embedded in the prompt as an inline
    ``<img>`` data URI.

    Returns:
        ``{"response": <model reply>}``.

    Raises:
        HTTPException: 400 if the uploaded file is empty; otherwise as
            raised while calling the upstream API.
    """
    # Read and encode the image file to base64
    image_data = await file.read()
    if not image_data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    base64_image = base64.b64encode(image_data).decode("utf-8")

    if file.content_type in _SUPPORTED_IMAGE_TYPES:
        mime_type = file.content_type
    else:
        mime_type = _DEFAULT_IMAGE_TYPE

    # Prepare the payload for the NVIDIA API. The image must be part of the
    # message content, otherwise the model only ever sees the text question.
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {
                "role": "user",
                "content": (
                    "What is in this image? "
                    f'<img src="data:{mime_type};base64,{base64_image}" />'
                ),
            }
        ],
        "max_tokens": 512,
        "temperature": 1.0,
        "top_p": 1.0,
        "stream": False,
    }
    return await _complete(payload)


# Root endpoint
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the NVIDIA API FastAPI wrapper!"}