Fred808 commited on
Commit
acbb541
·
verified ·
1 Parent(s): df5b28a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -25,6 +25,10 @@ logger = logging.getLogger(__name__)
25
  model_id = "EleutherAI/gpt-neo-125M" # Use GPT-Neo-125M for faster performance
26
  tokenizer = AutoTokenizer.from_pretrained(model_id)
27
 
 
 
 
 
28
  # Load the model without quantization for CPU
29
  logger.info("Loading model...")
30
  model = AutoModelForCausalLM.from_pretrained(
@@ -71,6 +75,11 @@ async def generate_text(request: TextGenerationRequest):
71
  raise HTTPException(status_code=500, detail=str(e))
72
 
73
  # Add a root endpoint for health checks
74
- @app.get("/test")
75
  async def root():
 
 
 
 
 
76
  return {"message": "API is running!"}
 
25
  model_id = "EleutherAI/gpt-neo-125M" # Use GPT-Neo-125M for faster performance
26
  tokenizer = AutoTokenizer.from_pretrained(model_id)
27
 
28
+ # Set pad_token if it doesn't exist
29
+ if tokenizer.pad_token is None:
30
+ tokenizer.pad_token = tokenizer.eos_token # Use eos_token as pad_token
31
+
32
  # Load the model without quantization for CPU
33
  logger.info("Loading model...")
34
  model = AutoModelForCausalLM.from_pretrained(
 
75
  raise HTTPException(status_code=500, detail=str(e))
76
 
77
  # Add a root endpoint for health checks
78
+ @app.get("/")
79
  async def root():
80
+ return {"message": "Welcome to the Text Generation API! Use /generate-text to generate text."}
81
+
82
+ # Add a test endpoint
83
+ @app.get("/test")
84
+ async def test():
85
  return {"message": "API is running!"}