Fred808 committed
Commit e71208c · verified · 1 Parent(s): 64c0b0e

Update app.py

Files changed (1): app.py +8 -8
app.py CHANGED
@@ -6,8 +6,8 @@ import torch
 # Initialize FastAPI app
 app = FastAPI()
 
-# Load the Falcon-7B model with 8-bit quantization (if CUDA is available)
-model_id = "tiiuae/falcon-7b-instruct"
+# Load the latest Falcon-7B model with 8-bit quantization (if CUDA is available)
+model_id = "tiiuae/falcon-7b-instruct"  # Update this if there's a newer version
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Check if CUDA is available
@@ -15,16 +15,16 @@ if torch.cuda.is_available():
     # Load the model with 8-bit quantization for GPU
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        load_in_8bit=True,
-        device_map="auto",
-        trust_remote_code=True
+        load_in_8bit=True,  # Use 8-bit quantization for GPU
+        device_map="auto",  # Automatically map the model to available devices
+        trust_remote_code=True  # Required for Falcon models
     )
 else:
     # Fallback to CPU or full precision
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto",
-        trust_remote_code=True
+        device_map="auto",  # Automatically map the model to available devices
+        trust_remote_code=True  # Required for Falcon models
     )
 
 # Create a text generation pipeline
@@ -59,4 +59,4 @@ async def generate_text(request: TextGenerationRequest):
 # Add a root endpoint for health checks
 @app.get("/test")
 async def root():
-    return {"message": "API is running!"}
+    return {"message": "API is running!"}
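
Note: passing load_in_8bit=True directly to from_pretrained() is deprecated in recent transformers releases in favor of an explicit BitsAndBytesConfig. A minimal sketch of the equivalent load path, assuming a transformers version with BitsAndBytesConfig and bitsandbytes installed (the version assumption and the CPU fallback shape are illustrative, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

if torch.cuda.is_available():
    # Newer API: wrap the 8-bit flag in a BitsAndBytesConfig instead of
    # passing load_in_8bit=True straight to from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",  # map layers across available devices
        trust_remote_code=True,
    )
else:
    # CPU fallback: full precision, no bitsandbytes quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
    )

Behavior should match the committed code; only the route for the quantization flag changes.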