Fred808 committed on
Commit
0fb92e1
·
verified ·
1 Parent(s): dd25f43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -8
app.py CHANGED
@@ -17,18 +17,29 @@ logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
  # Hugging Face Inference API endpoint for DeepSeek
20
- API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # Replace with the correct model ID
21
  headers = {"Authorization": f"Bearer {token}"}
22
 
23
  # Define request body schema
24
  class TextGenerationRequest(BaseModel):
25
  prompt: str
26
- max_new_tokens: int = 50 # Reduce this for faster responses
27
- temperature: float = 0.7 # Lower for more deterministic outputs
28
- top_k: int = 50
29
- top_p: float = 0.9
 
30
  do_sample: bool = True
31
 
 
 
 
 
 
 
 
 
 
 
32
  # Define API endpoint
33
  @app.post("/generate-text")
34
  async def generate_text(request: TextGenerationRequest):
@@ -43,6 +54,7 @@ async def generate_text(request: TextGenerationRequest):
43
  "temperature": request.temperature,
44
  "top_k": request.top_k,
45
  "top_p": request.top_p,
 
46
  "do_sample": request.do_sample,
47
  },
48
  }
@@ -55,9 +67,10 @@ async def generate_text(request: TextGenerationRequest):
55
  logger.error(f"API Error: {response.status_code} - {response.text}")
56
  raise HTTPException(status_code=response.status_code, detail=response.text)
57
 
58
- # Extract the generated text from the response
59
  generated_text = response.json()[0]["generated_text"]
60
- return {"generated_text": generated_text}
 
61
  except Exception as e:
62
  logger.error(f"Error generating text: {e}")
63
  raise HTTPException(status_code=500, detail=str(e))
@@ -70,4 +83,4 @@ async def root():
70
  # Add a test endpoint
71
  @app.get("/test")
72
  async def test():
73
- return {"message": "API is running!"}
 
17
  logger = logging.getLogger(__name__)
18
 
19
  # Hugging Face Inference API endpoint for DeepSeek
20
+ API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-model" # Replace with the correct model ID
21
  headers = {"Authorization": f"Bearer {token}"}
22
 
23
  # Define request body schema
24
  class TextGenerationRequest(BaseModel):
25
  prompt: str
26
+ max_new_tokens: int = 150 # Increased to allow longer responses
27
+ temperature: float = 0.8 # Slightly increased for more varied responses
28
+ top_k: int = 100 # Broader vocabulary sampling
29
+ top_p: float = 0.92 # Increased diversity
30
+ repetition_penalty: float = 1.2 # Penalizes repetition
31
  do_sample: bool = True
32
 
# Function to remove repetitive content in the output
def remove_repetition(text: str, *, max_phrase_words: int = 8) -> str:
    """Collapse immediately repeated words and phrases in *text*.

    Only back-to-back repetition is removed ("the the cat" -> "the cat",
    "I am I am here" -> "I am here"). A word that legitimately appears
    again later in the text is kept, so normal prose is left intact.

    Comparison is case-insensitive and the first occurrence's casing is
    kept. The text is tokenized on whitespace and re-joined with single
    spaces, so runs of whitespace (including newlines) are normalized.

    Args:
        text: The generated text to clean.
        max_phrase_words: Longest repeated phrase, in words, that is
            detected and collapsed. Values below 1 disable collapsing.

    Returns:
        The text with consecutive duplicate words/phrases removed.
    """
    kept = []  # words retained so far, original casing
    keys = []  # case-folded twins of ``kept`` used for comparison

    for word in text.split():
        kept.append(word)
        keys.append(word.casefold())

        # A phrase of ``size`` words can only repeat if 2 * size words exist.
        longest = min(max_phrase_words, len(keys) // 2)
        for size in range(1, longest + 1):
            # The tail equals the phrase right before it: drop the copy.
            if keys[-size:] == keys[-2 * size:-size]:
                del kept[-size:]
                del keys[-size:]
                break

    return " ".join(kept)
42
+
43
  # Define API endpoint
44
  @app.post("/generate-text")
45
  async def generate_text(request: TextGenerationRequest):
 
54
  "temperature": request.temperature,
55
  "top_k": request.top_k,
56
  "top_p": request.top_p,
57
+ "repetition_penalty": request.repetition_penalty,
58
  "do_sample": request.do_sample,
59
  },
60
  }
 
67
  logger.error(f"API Error: {response.status_code} - {response.text}")
68
  raise HTTPException(status_code=response.status_code, detail=response.text)
69
 
70
+ # Extract and process the generated text
71
  generated_text = response.json()[0]["generated_text"]
72
+ cleaned_text = remove_repetition(generated_text) # Clean up repetition
73
+ return {"generated_text": cleaned_text}
74
  except Exception as e:
75
  logger.error(f"Error generating text: {e}")
76
  raise HTTPException(status_code=500, detail=str(e))
 
# Add a test endpoint
@app.get("/test")
async def test():
    """Return a fixed message payload for GET ``/test``."""
    payload = {"message": "API is running!"}
    return payload