Fred808 committed
Commit d501abc · verified · 1 Parent(s): 03ac765

Update app.py

Files changed (1): app.py +43 -16
app.py CHANGED
@@ -1,8 +1,10 @@
 from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 import os
 import logging
 import openai
+from typing import Optional
 
 # Read the NVIDIA API key from environment variables
 api_key = os.getenv("NVIDIA_API_KEY")
@@ -26,10 +28,13 @@ class TextGenerationRequest(BaseModel):
     max_new_tokens: int = 1024
     temperature: float = 0.4
     top_p: float = 0.7
-    stream: bool = True
 
-# Define API endpoint to generate text
-@app.post("/generate-text")
+# Define response schema for non-streaming
+class TextGenerationResponse(BaseModel):
+    generated_text: str
+
+# Define API endpoint for non-streaming text generation
+@app.post("/generate-text", response_model=TextGenerationResponse)
 async def generate_text(request: TextGenerationRequest):
     try:
         logger.info("Generating text with NVIDIA API...")
@@ -41,29 +46,51 @@ async def generate_text(request: TextGenerationRequest):
             temperature=request.temperature,
             top_p=request.top_p,
             max_tokens=request.max_new_tokens,
-            stream=request.stream
+            stream=False  # Non-streaming response
         )
 
-        response_text = ""
-        if request.stream:
-            # Handle streaming response
+        # Extract the generated text
+        response_text = response["choices"][0]["message"]["content"]
+        logger.info("Text generation completed successfully.")
+        return {"generated_text": response_text}
+
+    except Exception as e:
+        logger.error(f"Error generating text: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Define API endpoint for streaming text generation
+@app.post("/generate-text-stream")
+async def generate_text_stream(request: TextGenerationRequest):
+    async def generate():
+        try:
+            logger.info("Streaming text with NVIDIA API...")
+
+            # Prepare the payload for the NVIDIA API request
+            response = openai.ChatCompletion.create(
+                model="meta/llama-3.1-405b-instruct",  # Model for NVIDIA API
+                messages=[{"role": "user", "content": request.prompt}],
+                temperature=request.temperature,
+                top_p=request.top_p,
+                max_tokens=request.max_new_tokens,
+                stream=True  # Streaming response
+            )
+
+            # Stream the response chunks to the client
             for chunk in response:
                 if isinstance(chunk, dict):  # Ensure the chunk is a dictionary
                     # Extract content from each chunk safely
                     content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                     if content:
-                        response_text += content
-                        print(content, end="")  # Print content as it is streamed
+                        yield content  # Stream content to the client
                 else:
                     logger.error(f"Unexpected chunk format: {chunk}")  # Log if the chunk format is unexpected
-        else:
-            response_text = response["choices"][0]["message"]["content"]
 
-        return {"generated_text": response_text}
+            logger.info("Text streaming completed successfully.")
+        except Exception as e:
+            logger.error(f"Error streaming text: {e}")
+            yield f"Error: {str(e)}"
 
-    except Exception as e:
-        logger.error(f"Error generating text: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+    return StreamingResponse(generate(), media_type="text/plain")
 
 # Add a root endpoint for health checks
 @app.get("/")
@@ -73,4 +100,4 @@ async def root():
 # Add a test endpoint
 @app.get("/test")
 async def test():
-    return {"message": "API is running!"}
+    return {"message": "API is running!"}