Fred808 committed on
Commit
16de63c
·
verified ·
1 Parent(s): 28bfc95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -49
app.py CHANGED
@@ -1,13 +1,13 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- import requests
4
  import os
5
  import logging
 
6
 
7
- # Read the Hugging Face token from the environment variable
8
- token = os.getenv("HUGGING_FACE_HUB_TOKEN")
9
- if token is None:
10
- raise ValueError("Hugging Face token not found in environment variables. Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces.")
11
 
12
  # Initialize FastAPI app
13
  app = FastAPI()
@@ -16,61 +16,43 @@ app = FastAPI()
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # Hugging Face Inference API endpoint for DeepSeek
20
- API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # Replace with the correct model ID
21
- headers = {"Authorization": f"Bearer {token}"}
22
 
23
  # Define request body schema
24
  class TextGenerationRequest(BaseModel):
25
  prompt: str
26
- max_new_tokens: int = 150 # Increased to allow longer responses
27
- temperature: float = 0.8 # Slightly increased for more varied responses
28
- top_k: int = 100 # Broader vocabulary sampling
29
- top_p: float = 0.92 # Increased diversity
30
- repetition_penalty: float = 1.2 # Penalizes repetition
31
- do_sample: bool = True
32
-
33
- # Function to remove repetitive content in the output
34
- def remove_repetition(text: str) -> str:
35
- seen = set()
36
- result = []
37
- for word in text.split():
38
- if word.lower() not in seen:
39
- result.append(word)
40
- seen.add(word.lower())
41
- return " ".join(result)
42
 
43
  # Define API endpoint
44
  @app.post("/generate-text")
45
  async def generate_text(request: TextGenerationRequest):
46
  try:
47
  logger.info("Generating text...")
48
-
49
- # Prepare the payload for the Hugging Face Inference API
50
- payload = {
51
- "inputs": request.prompt,
52
- "parameters": {
53
- "max_new_tokens": request.max_new_tokens,
54
- "temperature": request.temperature,
55
- "top_k": request.top_k,
56
- "top_p": request.top_p,
57
- "repetition_penalty": request.repetition_penalty,
58
- "do_sample": request.do_sample,
59
- },
60
- }
61
 
62
- # Send request to the Hugging Face Inference API
63
- response = requests.post(API_URL, headers=headers, json=payload)
64
-
65
- # Check for errors in the response
66
- if response.status_code != 200:
67
- logger.error(f"API Error: {response.status_code} - {response.text}")
68
- raise HTTPException(status_code=response.status_code, detail=response.text)
 
 
 
 
 
 
 
 
 
 
69
 
70
- # Extract and process the generated text
71
- generated_text = response.json()[0]["generated_text"]
72
- cleaned_text = remove_repetition(generated_text) # Clean up repetition
73
- return {"generated_text": cleaned_text}
74
  except Exception as e:
75
  logger.error(f"Error generating text: {e}")
76
  raise HTTPException(status_code=500, detail=str(e))
@@ -78,7 +60,7 @@ async def generate_text(request: TextGenerationRequest):
78
  # Add a root endpoint for health checks
79
  @app.get("/")
80
  async def root():
81
- return {"message": "Welcome to the Text Generation API! Use /generate-text to generate text."}
82
 
83
  # Add a test endpoint
84
  @app.get("/test")
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
3
  import os
4
  import logging
5
+ import openai
6
 
7
# Fail fast at import time: the service is useless without credentials,
# so refuse to start rather than 500 on every request later.
api_key = os.getenv("NVIDIA_API_KEY")
if api_key is None:
    raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")

# Initialize FastAPI app
app = FastAPI()

# Module-level logger; INFO level so request tracing shows up in Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the openai SDK's module-level client at NVIDIA's
# OpenAI-compatible endpoint (requires openai>=1.0 for `base_url`).
openai.api_key = api_key
openai.base_url = "https://integrate.api.nvidia.com/v1"
22
 
23
# Define request body schema
class TextGenerationRequest(BaseModel):
    """Payload accepted by POST /generate-text.

    Defaults mirror NVIDIA's recommended sampling settings; `stream`
    controls whether the upstream call is consumed chunk-by-chunk.
    """

    prompt: str
    max_new_tokens: int = 1024
    temperature: float = 0.4
    top_p: float = 0.7
    stream: bool = True
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
# Define API endpoint
@app.post("/generate-text")
async def generate_text(request: TextGenerationRequest):
    """Generate a completion for ``request.prompt`` via NVIDIA's API.

    Returns ``{"generated_text": <str>}``.
    Raises HTTPException(500) with the underlying error message on failure.
    """
    try:
        logger.info("Generating text...")

        # FIX: openai>=1.0 removed `openai.ChatCompletion`; since the module
        # is configured via `openai.base_url` (a >=1.0 feature), use the
        # matching >=1.0 call path.
        response = openai.chat.completions.create(
            model="meta/llama-3.1-405b-instruct",
            messages=[{"role": "user", "content": request.prompt}],
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_new_tokens,
            stream=request.stream,
        )

        if request.stream:
            # FIX: in openai>=1.0 `delta` is an object, not a dict — use
            # attribute access. Role/stop chunks carry content=None, so
            # guard before appending; join once instead of += in the loop.
            parts = []
            for chunk in response:
                delta = chunk.choices[0].delta
                if delta.content:
                    parts.append(delta.content)
            response_text = "".join(parts)
        else:
            # FIX: attribute access instead of dict indexing (>=1.0 objects).
            response_text = response.choices[0].message.content

        return {"generated_text": response_text}
    except Exception as e:
        # Surface the upstream/client error to the caller as a 500.
        logger.error(f"Error generating text: {e}")
        raise HTTPException(status_code=500, detail=str(e))
 
60
# Add a root endpoint for health checks
@app.get("/")
async def root():
    """Liveness probe: confirms the service is up."""
    return {"message": "Welcome Fred808 GPT"}
64
 
65
  # Add a test endpoint
66
  @app.get("/test")