Fred808 committed on
Commit
7c1d81b
·
verified ·
1 Parent(s): 36267e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -24
app.py CHANGED
@@ -1,12 +1,17 @@
 
 
1
  import os
2
- import requests
3
  import logging
 
4
 
5
  # Read the NVIDIA API key from environment variables
6
  api_key = os.getenv("NVIDIA_API_KEY")
7
  if api_key is None:
8
  raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")
9
 
 
 
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
@@ -18,32 +23,62 @@ headers = {
18
  "Content-Type": "application/json"
19
  }
20
 
21
- # Define request payload
22
- payload = {
23
- "model": "meta/llama-3.1-405b-instruct", # Model for NVIDIA's text generation
24
- "messages": [{"role": "user", "content": "Write a limerick about the wonders of GPU computing."}],
25
- "temperature": 0.2,
26
- "top_p": 0.7,
27
- "max_tokens": 1024,
28
- "stream": True
29
- }
 
 
 
 
30
 
31
- # Call NVIDIA's API for text generation
32
- try:
33
- logger.info("Generating text with NVIDIA API...")
34
- response = requests.post(f"{base_url}/chat/completions", headers=headers, json=payload, stream=True)
 
 
 
 
 
35
 
36
- if response.status_code == 200:
37
- # Stream the response
 
 
 
 
 
38
  response_text = ""
39
  for chunk in response.iter_lines():
40
  if chunk:
41
  data = chunk.decode("utf-8")
42
- # Extract the content from the response (adjust based on actual API response structure)
43
- if "content" in data:
44
- response_text += data["choices"][0]["delta"].get("content", "")
45
- print(response_text, end="") # Print content as it's received
46
- else:
47
- logger.error(f"Error: {response.status_code} - {response.text}")
48
- except Exception as e:
49
- logger.error(f"Error generating text: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  import os
 
4
  import logging
5
+ import requests
6
 
7
  # Read the NVIDIA API key from environment variables
8
  api_key = os.getenv("NVIDIA_API_KEY")
9
  if api_key is None:
10
  raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")
11
 
12
+ # Initialize FastAPI app
13
+ app = FastAPI()
14
+
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
 
23
  "Content-Type": "application/json"
24
  }
25
 
26
+ # Define request body schema
27
class TextGenerationRequest(BaseModel):
    """Request body schema for the text-generation endpoint.

    Only ``prompt`` is required; every other field falls back to the
    default declared below when the client omits it.
    """

    # Text placed in the single "user" message sent upstream.
    prompt: str

    # Forwarded upstream under the "max_tokens" key.
    max_new_tokens: int = 1024

    # Sampling parameters, forwarded upstream unchanged.
    temperature: float = 0.4
    top_p: float = 0.7

    # Forwarded upstream as the "stream" flag.
    stream: bool = True
33
+
34
+ # Define API endpoint to generate text
35
def _extract_stream_content(line: bytes) -> str:
    """Return the text fragment carried by one streamed response line.

    Assumes the upstream streams OpenAI-style server-sent events, i.e.
    ``data: {json}`` records terminated by ``data: [DONE]`` -- TODO confirm
    against the NVIDIA endpoint actually configured in ``base_url``.
    Lines that are not data records, or that cannot be parsed, yield "".
    """
    import json

    text = line.decode("utf-8").strip()
    if not text.startswith("data:"):
        # SSE comments / other fields carry no generated text.
        return ""

    body = text[len("data:"):].strip()
    if not body or body == "[DONE]":
        return ""

    try:
        event = json.loads(body)
    except ValueError as exc:
        # A single malformed record should not abort the whole generation.
        logger.error(f"Error processing chunk: {exc}")
        return ""

    if not isinstance(event, dict):
        return ""
    choices = event.get("choices") or [{}]
    first = choices[0] if isinstance(choices[0], dict) else {}
    delta = first.get("delta") or {}
    return delta.get("content") or ""


def _request_completion(payload: dict) -> str:
    """Call the upstream chat-completions endpoint and return the full text.

    This function blocks on network I/O and is meant to be run in a worker
    thread.

    Raises:
        HTTPException: with the upstream status code when the upstream
            answers with anything other than 200.
    """
    # NOTE(review): no timeout is set, so a stalled upstream blocks the
    # worker thread indefinitely -- consider passing ``timeout=``.
    with requests.post(
        f"{base_url}/chat/completions",
        headers=headers,
        json=payload,
        stream=True,
    ) as response:
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Error: {response.text}",
            )

        if not payload.get("stream"):
            # Non-streaming replies are a single JSON document whose text is
            # presumably under choices[0].message.content (OpenAI-style).
            choices = response.json().get("choices") or [{}]
            message = choices[0].get("message") or {}
            return message.get("content") or ""

        return "".join(
            _extract_stream_content(chunk)
            for chunk in response.iter_lines()
            if chunk
        )


@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate text for ``request.prompt`` via the NVIDIA API.

    The upstream response is consumed completely (streamed or not) and the
    concatenated text is returned in a single JSON object.

    Returns:
        ``{"generated_text": <str>}``

    Raises:
        HTTPException: the upstream status code when the upstream call fails
            with a non-200 response, or 500 for any other error.
    """
    import asyncio

    payload = {
        "model": "meta/llama-3.1-405b-instruct",  # NVIDIA-specific model
        "messages": [{"role": "user", "content": request.prompt}],
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_new_tokens,
        "stream": request.stream,
    }

    try:
        logger.info("Generating text with NVIDIA API...")
        # ``requests`` is blocking; run it off the event loop so other
        # requests keep being served while the model generates.
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None, _request_completion, payload
        )
    except HTTPException:
        # Preserve the upstream status code instead of masking it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"generated_text": response_text}
75
+
76
+ # Add a root endpoint for health checks
77
@app.get("/")
async def root():
    """Return a static welcome message for the root path."""
    welcome = "Welcome to the NVIDIA Text Generation API!"
    return {"message": welcome}
80
+
81
+ # Add a test endpoint
82
@app.get("/test")
async def test():
    """Return a static message confirming the API process is serving requests."""
    status = "API is running!"
    return {"message": status}