Fred808 committed
Commit 03ac765 · verified · 1 Parent(s): 7c1d81b

Update app.py

Files changed (1)
  1. app.py +25 -33
app.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import os
 import logging
-import requests
+import openai
 
 # Read the NVIDIA API key from environment variables
 api_key = os.getenv("NVIDIA_API_KEY")
@@ -16,12 +16,9 @@ app = FastAPI()
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# NVIDIA API configuration
-base_url = "https://integrate.api.nvidia.com/v1"
-headers = {
-    "Authorization": f"Bearer {api_key}",
-    "Content-Type": "application/json"
-}
+# Configure OpenAI client to use NVIDIA's API (via OpenAI wrapper)
+openai.api_key = api_key  # Using the NVIDIA API key
+openai.api_base = "https://integrate.api.nvidia.com/v1"  # Set the NVIDIA base URL
 
 # Define request body schema
 class TextGenerationRequest(BaseModel):
@@ -38,35 +35,30 @@ async def generate_text(request: TextGenerationRequest):
         logger.info("Generating text with NVIDIA API...")
 
         # Prepare the payload for the NVIDIA API request
-        payload = {
-            "model": "meta/llama-3.1-405b-instruct",  # NVIDIA-specific model
-            "messages": [{"role": "user", "content": request.prompt}],
-            "temperature": request.temperature,
-            "top_p": request.top_p,
-            "max_tokens": request.max_new_tokens,
-            "stream": request.stream
-        }
+        response = openai.ChatCompletion.create(
+            model="meta/llama-3.1-405b-instruct",  # Model for NVIDIA API
+            messages=[{"role": "user", "content": request.prompt}],
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_new_tokens,
+            stream=request.stream
+        )
 
-        # Send POST request to NVIDIA API (streaming enabled)
-        response = requests.post(f"{base_url}/chat/completions", headers=headers, json=payload, stream=True)
-
-        if response.status_code != 200:
-            raise HTTPException(status_code=response.status_code, detail=f"Error: {response.text}")
-
-        # Process the streaming response
         response_text = ""
-        for chunk in response.iter_lines():
-            if chunk:
-                data = chunk.decode("utf-8")
-                # Assuming the API response contains 'choices' and 'delta'
-                try:
-                    content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
+        if request.stream:
+            # Handle streaming response
+            for chunk in response:
+                if isinstance(chunk, dict):  # Ensure the chunk is a dictionary
+                    # Extract content from each chunk safely
+                    content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                     if content:
                         response_text += content
-                    print(content, end="")  # Print the content to stream it out
-                except Exception as e:
-                    logger.error(f"Error processing chunk: {e}")
-
+                        print(content, end="")  # Print content as it is streamed
+                else:
+                    logger.error(f"Unexpected chunk format: {chunk}")  # Log if the chunk format is unexpected
+        else:
+            response_text = response["choices"][0]["message"]["content"]
+
         return {"generated_text": response_text}
 
     except Exception as e:
@@ -76,7 +68,7 @@ async def generate_text(request: TextGenerationRequest):
 # Add a root endpoint for health checks
 @app.get("/")
 async def root():
-    return {"message": "Welcome to the NVIDIA Text Generation API!"}
+    return {"message": "Welcome to the NVIDIA Text Generation API using OpenAI Wrapper!"}
 
 # Add a test endpoint
 @app.get("/test")
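For reference, a minimal sketch of calling the updated text-generation endpoint from a separate client script. The route path "/generate" and the local host/port are assumptions (the decorator on generate_text is not shown in this diff); the field names mirror the TextGenerationRequest attributes used in the handler above (prompt, temperature, top_p, max_new_tokens, stream).

# Sketch only: exercising the updated service from a client.
# The route path "/generate" and host "http://localhost:8000" are assumptions, not part of this commit.
import requests

payload = {
    "prompt": "Write a one-line greeting.",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 64,
    "stream": False,  # non-streaming: the handler reads response["choices"][0]["message"]["content"]
}

resp = requests.post("http://localhost:8000/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])

Note that even with stream=True the handler aggregates the chunks into response_text and returns a single JSON object, so the HTTP client still receives one response; the incremental output is only printed on the server. The sketch also assumes the service runs with an openai package older than 1.0, since openai.api_base and openai.ChatCompletion.create were removed in later releases.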