from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
from openai import AsyncOpenAI

# Read the NVIDIA API key from environment variables
api_key = os.getenv("NVIDIA_API_KEY")
if api_key is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure an async OpenAI client pointed at NVIDIA's OpenAI-compatible endpoint.
# (openai>=1.0 replaces the module-level openai.api_key / openai.api_base
# configuration used by earlier versions of the client library.)
client = AsyncOpenAI(
    api_key=api_key,  # NVIDIA API key read above
    base_url="https://integrate.api.nvidia.com/v1",  # NVIDIA's base URL
)

# Define request body schema
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 1024
    temperature: float = 0.4
    top_p: float = 0.7
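
# For illustration, a request body for either endpoint could look like this
# (the prompt is a placeholder; omitted fields fall back to the defaults above):
#   {"prompt": "Explain mixture-of-experts in two sentences.", "max_new_tokens": 256}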

# Define response schema for non-streaming
class TextGenerationResponse(BaseModel):
    generated_text: str

# Define API endpoint for non-streaming text generation
@app.post("/generate-text", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    try:
        logger.info("Generating text with NVIDIA API...")

        # Request a full (non-streaming) completion from the NVIDIA API
        response = await client.chat.completions.create(
            model="meta/llama-3.1-405b-instruct",  # Model served by NVIDIA's API
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=False,  # Return the whole completion at once
        )

        # Extract the generated text (openai>=1.0 returns typed objects, not dicts)
        response_text = response.choices[0].message.content
        logger.info("Text generation completed successfully.")
        return {"generated_text": response_text}

    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Define API endpoint for streaming text generation
@app.post("/generate-text-stream")
async def generate_text_stream(request: TextGenerationRequest):
    async def generate():
        try:
            logger.info("Streaming text with NVIDIA API...")

            # Request a streaming completion from the NVIDIA API
            stream = await client.chat.completions.create(
                model="meta/llama-3.1-405b-instruct",  # Model served by NVIDIA's API
                messages=[{"role": "user", "content": request.prompt}],
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_new_tokens,
                stream=True,  # Yield the completion incrementally
            )

            # Forward each chunk's text to the client as it arrives; delta.content
            # may be None on role/finish chunks, so guard before yielding
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

            logger.info("Text streaming completed successfully.")
        except Exception as e:
            logger.error(f"Error streaming text: {e}")
            yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")
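
# Illustrative only: one way to watch the stream from a shell, assuming the app
# is served at localhost:8000 (host/port are assumptions, not part of this file):
#   curl -N -X POST http://localhost:8000/generate-text-stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs."}'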

# Add a root endpoint for health checks
@app.get("/")
async def root():
    return {"message": "Welcome to the NVIDIA Text Generation API using OpenAI Wrapper!"}

# Add a test endpoint
@app.get("/test")
async def test():
    return {"message": "API is running!"}