Fred808 committed
Commit e0e5738 · verified · 1 Parent(s): 4b7f924

Update app.py

Files changed (1)
app.py +14 -8
app.py CHANGED
@@ -2,12 +2,17 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
+import logging
 
 # Initialize FastAPI app
 app = FastAPI()
 
-# Load the latest Falcon-7B model with 8-bit quantization (if CUDA is available)
-model_id = "tiiuae/falcon-7b-instruct"  # Update this if there's a newer version
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load the Falcon-7B model with 8-bit quantization (if CUDA is available)
+model_id = "tiiuae/falcon-7b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Check if CUDA is available
@@ -15,16 +20,16 @@ if torch.cuda.is_available():
     # Load the model with 8-bit quantization for GPU
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        load_in_8bit=True,  # Use 8-bit quantization for GPU
-        device_map="auto",  # Automatically map the model to available devices
-        trust_remote_code=True  # Required for Falcon models
+        revision="main",  # Pin to a specific revision
+        load_in_8bit=True,
+        device_map="auto"
     )
 else:
     # Fallback to CPU or full precision
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto",  # Automatically map the model to available devices
-        trust_remote_code=True  # Required for Falcon models
+        revision="main",  # Pin to a specific revision
+        device_map="auto"
     )
 
 # Create a text generation pipeline
@@ -43,7 +48,7 @@ class TextGenerationRequest(BaseModel):
 @app.post("/generate-text")
 async def generate_text(request: TextGenerationRequest):
     try:
-        # Generate text using the pipeline
+        logger.info("Generating text...")
         outputs = pipe(
             request.prompt,
             max_new_tokens=request.max_new_tokens,
@@ -54,6 +59,7 @@ async def generate_text(request: TextGenerationRequest):
         )
         return {"generated_text": outputs[0]["generated_text"]}
     except Exception as e:
+        logger.error(f"Error generating text: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
 # Add a root endpoint for health checks
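
For context on the quantized load path: in recent transformers releases, passing load_in_8bit=True directly to from_pretrained is deprecated in favor of a BitsAndBytesConfig quantization config (8-bit loading requires the bitsandbytes package and a CUDA GPU, and device_map="auto" additionally needs accelerate). A minimal sketch of the equivalent call, assuming a recent transformers version:

```python
# Sketch: 8-bit load via BitsAndBytesConfig, the newer equivalent of
# load_in_8bit=True (assumes transformers with bitsandbytes and
# accelerate installed, and a CUDA-capable GPU).
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b-instruct",
    revision="main",                   # pinned, as in the commit
    quantization_config=quant_config,  # replaces load_in_8bit=True
    device_map="auto",                 # map layers to available devices
)
```

Two side notes on the diff itself: dropping trust_remote_code=True should be safe on newer transformers releases, since the Falcon architecture has been merged into the library; and revision="main" pins a branch head rather than an immutable snapshot, so pinning a specific commit SHA gives stronger reproducibility.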
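
To exercise the /generate-text route shown in the last two hunks, a client only needs the fields visible in the diff (prompt and max_new_tokens; any other TextGenerationRequest fields sit outside the shown hunks). A minimal client sketch, assuming the app is served locally on port 8000:

```python
# Sketch: calling the /generate-text endpoint using only the request
# fields visible in the diff; assumes the server runs on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/generate-text",
    json={
        "prompt": "Explain 8-bit quantization in one sentence.",
        "max_new_tokens": 64,
    },
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```

On success the endpoint returns {"generated_text": ...}; on any exception it responds with HTTP 500 carrying the error message, which this commit now also records via logger.error.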