Fred808 committed
Commit 9ddd59a · verified · 1 parent: acbb541

Update app.py: replace the local GPT-Neo-125M transformers pipeline with calls to the Hugging Face Inference API (bigscience/bloom-7b1).

Files changed (1): app.py (+23, -39)
app.py CHANGED
@@ -1,9 +1,7 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from huggingface_hub import login
+import requests
 import os
-import torch
 import logging
 
 # Read the Hugging Face token from the environment variable
@@ -11,9 +9,6 @@ token = os.getenv("HUGGING_FACE_HUB_TOKEN")
 if token is None:
     raise ValueError("Hugging Face token not found in environment variables. Please set the HUGGING_FACE_HUB_TOKEN secret in Hugging Face Spaces.")
 
-# Log in with the token
-login(token=token)
-
 # Initialize FastAPI app
 app = FastAPI()
 
@@ -21,28 +16,9 @@ app = FastAPI()
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Load a smaller model (GPT-Neo-125M) and tokenizer
-model_id = "EleutherAI/gpt-neo-125M"  # Use GPT-Neo-125M for faster performance
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-# Set pad_token if it doesn't exist
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token
-
-# Load the model without quantization for CPU
-logger.info("Loading model...")
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.float32,  # Use FP32 for CPU compatibility
-    device_map="auto"  # Automatically offload to available devices
-)
-
-# Create a text generation pipeline
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer
-)
+# Hugging Face Inference API endpoint for BLOOM-7B
+API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom-7b1"  # Use BLOOM-7B
+headers = {"Authorization": f"Bearer {token}"}
 
 # Define request body schema
 class TextGenerationRequest(BaseModel):
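The two constants added above are everything the hosted Inference API needs, so the new code path can be sanity-checked outside the app. A minimal sketch, assuming the same HUGGING_FACE_HUB_TOKEN environment variable is set; the prompt is a placeholder:

import os
import requests

API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom-7b1"
headers = {"Authorization": f"Bearer {os.environ['HUGGING_FACE_HUB_TOKEN']}"}

# Minimal request; the handler in the next hunk builds the same shape with extra parameters
resp = requests.post(API_URL, headers=headers, json={"inputs": "Once upon a time"})
resp.raise_for_status()  # surfaces 4xx/5xx errors from the hosted API
print(resp.json()[0]["generated_text"])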
@@ -59,17 +35,25 @@ async def generate_text(request: TextGenerationRequest):
     try:
         logger.info("Generating text...")
 
-        # Generate text using the pipeline with the user's prompt
-        outputs = pipe(
-            request.prompt,  # Use the user's prompt directly
-            max_new_tokens=request.max_new_tokens,
-            temperature=request.temperature,
-            top_k=request.top_k,
-            top_p=request.top_p,
-            do_sample=request.do_sample,
-            return_full_text=False  # Exclude the input prompt from the output
-        )
-        return {"generated_text": outputs[0]["generated_text"]}
+        # Prepare the payload for the Hugging Face Inference API
+        payload = {
+            "inputs": request.prompt,
+            "parameters": {
+                "max_new_tokens": request.max_new_tokens,
+                "temperature": request.temperature,
+                "top_k": request.top_k,
+                "top_p": request.top_p,
+                "do_sample": request.do_sample,
+            },
+        }
+
+        # Send request to the Hugging Face Inference API
+        response = requests.post(API_URL, headers=headers, json=payload)
+        response.raise_for_status()  # Raise an error for bad responses (4xx or 5xx)
+
+        # Extract the generated text from the response
+        generated_text = response.json()[0]["generated_text"]
+        return {"generated_text": generated_text}
     except Exception as e:
         logger.error(f"Error generating text: {e}")
         raise HTTPException(status_code=500, detail=str(e))
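The body of TextGenerationRequest is elided by the hunks above. Judging from the fields the handler reads, the schema presumably looks like the sketch below; the defaults are illustrative guesses, not taken from the file:

from pydantic import BaseModel

class TextGenerationRequest(BaseModel):
    # Field names mirror what generate_text reads; defaults are guesses
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True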
 
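End to end, the updated endpoint can be exercised as below once the Space is running. The /generate path and the host are assumptions (the @app.post decorator sits outside the hunks shown). Note that the hosted Inference API can answer 503 while bloom-7b1 is still loading; its payload also accepts an "options" object ({"wait_for_model": true}) to block until the model is ready.

import requests

# Hypothetical client call; the "/generate" route and the port are assumptions
resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Once upon a time",
        "max_new_tokens": 50,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "do_sample": True,
    },
)
resp.raise_for_status()
print(resp.json()["generated_text"])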