Fred808 committed
Commit 11ba705 · verified · 1 Parent(s): 2e59f2f

Update app.py

Files changed (1): app.py +16 -4
app.py CHANGED
@@ -11,13 +11,25 @@ app = FastAPI()
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Load the Google Gemma 2B model and tokenizer
-model_id = "google/gemma-2b"  # Use Google Gemma 2B
+# Load the Google Gemma 7B model and tokenizer
+model_id = "google/gemma-7b"  # Use Google Gemma 7B
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+
+# Load the model with 4-bit quantization to reduce VRAM usage
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,  # Use half-precision for faster inference
+    device_map="auto",          # Automatically offload to available GPUs
+    load_in_4bit=True           # Enable 4-bit quantization
+)
 
 # Create a text generation pipeline
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device="cuda" if torch.cuda.is_available() else "cpu")
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device="cuda" if torch.cuda.is_available() else "cpu"
+)
 
 # Define request body schema
 class TextGenerationRequest(BaseModel):
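
A note on the new code: in recent transformers releases, passing load_in_4bit=True directly to from_pretrained is deprecated in favor of a BitsAndBytesConfig, and passing a device= argument to pipeline() for a model loaded with device_map="auto" typically raises an error, since accelerate has already placed the model's layers. A minimal sketch of the equivalent load path, assuming transformers, accelerate, and bitsandbytes are installed and a CUDA GPU is available (this is an illustration of the current API, not part of the commit):

    # Sketch: 4-bit load via BitsAndBytesConfig, the documented way to
    # request quantization in current transformers versions.
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        pipeline,
    )

    model_id = "google/gemma-7b"

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # 4-bit weights to cut VRAM
        bnb_4bit_compute_dtype=torch.float16,  # half-precision compute
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",                 # accelerate places layers on available devices
        quantization_config=quant_config,  # replaces the deprecated load_in_4bit kwarg
    )

    # No device= argument here: with device_map="auto" the model is already
    # placed, and also passing a device to pipeline() raises a ValueError.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

With this setup, pipe("Hello, world", max_new_tokens=50) should run on the GPU without any manual device handling.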