Buildwellai committed
Commit e2569b2 · verified · 1 Parent(s): 9e2dc92

Update handler.py

Files changed (1): handler.py (+21 -24)

handler.py CHANGED
@@ -2,7 +2,7 @@ import os
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 import torch
-from unsloth import FastLanguageModel # Import Unsloth
+from unsloth import FastLanguageModel
 from dotenv import load_dotenv
 
 load_dotenv()
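
Review note: the only change in this hunk is dropping the trailing comment on the unsloth import. One adjacent point worth verifying separately (not part of this commit): Unsloth's docs recommend importing it before transformers so its runtime patches can be applied. A minimal sketch of that ordering, assuming both libraries stay in handler.py:

    # Sketch only: import unsloth ahead of transformers so Unsloth's
    # patches can take effect; check the pinned unsloth version's docs.
    from unsloth import FastLanguageModel
    from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
    import torch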
@@ -16,57 +16,55 @@ class EndpointHandler:
         max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
         max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
         self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
-        self.model_dir = os.getenv("MODEL_DIR", ".") # Ensure this is set correctly
+        self.model_dir = os.getenv("MODEL_DIR", ".")
 
-        print(f"MODEL_DIR: {self.model_dir}") # Debug print
-        print(f"Files in model directory: {os.listdir(self.model_dir)}") # VERY IMPORTANT
+        print(f"MODEL_DIR: {self.model_dir}")
+        print(f"Files in model directory: {os.listdir(self.model_dir)}")
 
         # --- 1. Load Config ---
-        # Load the configuration first to determine the base model type
-        self.config = AutoConfig.from_pretrained(self.model_dir, token=self.hf_token)
+        # Load the configuration first, WITH trust_remote_code=True
+        self.config = AutoConfig.from_pretrained(self.model_dir, token=self.hf_token, trust_remote_code=True)
 
         # --- 2. Load Tokenizer ---
-        # Load the tokenizer. Handle tokenizer loading errors.
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, token=self.hf_token)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, token=self.hf_token, trust_remote_code=True)
         except Exception as e:
             print(f"Error loading tokenizer: {e}")
             raise
 
         # --- 3. Load Model ---
-        # Load the base model *and* apply the LoRA adapter. Handle model loading.
         try:
-            # Load base model (using config to determine correct class)
+            # Load base model, WITH trust_remote_code=True
             self.model = AutoModelForCausalLM.from_pretrained(
-                self.config.base_model_name_or_path, # Use base model from config!
+                self.config.base_model_name_or_path,
                 config=self.config,
-                torch_dtype=torch.bfloat16, # Use bfloat16 if available
+                torch_dtype=torch.bfloat16,
                 token=self.hf_token,
-                device_map="auto", # Let transformers handle device placement
+                device_map="auto",
+                trust_remote_code=True, #CRUCIAL
             )
 
             # Load and apply the LoRA adapter
-            self.model = FastLanguageModel.get_peft_model(self.model, self.model_dir) #model_dir contains lora weights
-            FastLanguageModel.for_inference(self.model) #Unsloth speed up
+            self.model = FastLanguageModel.get_peft_model(self.model, self.model_dir)
+            FastLanguageModel.for_inference(self.model)
 
         except Exception as e:
             print(f"Error loading model: {e}")
             raise
 
-        # Define the prompt style
-        self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
-Write a response that appropriately completes the request.
+        # Define the prompt style
+        self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
+Write a response that appropriately completes the request.
 Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 ### Instruction:
-You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
-Always be professional and precise in your responses.
+You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
+Always be professional and precise in your responses.
+
 ### Question:
 {}
 ### Response:
 <think>{}"""
 
-
-
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Processes the input and generates a response.
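
Review note: the substantive change in this hunk is threading trust_remote_code=True through the config, tokenizer, and model loads, which is required for repos that ship custom modeling code; since that flag executes repository code, pinning a trusted revision is worth considering. Separately, FastLanguageModel.get_peft_model(self.model, self.model_dir) deserves a second look: in Unsloth's API, get_peft_model attaches a fresh LoRA adapter (its parameters are adapter hyperparameters such as r and target_modules), while loading already-trained adapter weights from a directory is normally done via peft. A hedged sketch of that alternative, assuming self.model_dir contains a trained PEFT adapter (adapter_config.json plus weights):

    # Sketch, not the committed code: attach trained LoRA weights from
    # self.model_dir onto the already-loaded base model, then eval mode.
    from peft import PeftModel

    self.model = PeftModel.from_pretrained(self.model, self.model_dir)
    self.model.eval()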
@@ -84,8 +82,7 @@ Always be professional and precise in your responses.
         if torch.cuda.is_available():
             input_tokens = input_tokens.to("cuda")
 
-
-        with torch.no_grad(): # Ensure no gradient calculation
+        with torch.no_grad():
             output_tokens = self.model.generate(
                 input_ids=input_tokens.input_ids,
                 attention_mask=input_tokens.attention_mask,
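
Review note: this hunk only removes a stray blank line and the comment on torch.no_grad(). Because the diff cuts off mid-call, here is a hedged sketch of how the generate call plausibly completes and gets decoded; max_new_tokens appears as a literal because the diff defines it only as a local in __init__, and everything past attention_mask is an assumption:

    # Sketch under the assumptions above; not the committed code.
    with torch.no_grad():
        output_tokens = self.model.generate(
            input_ids=input_tokens.input_ids,
            attention_mask=input_tokens.attention_mask,
            max_new_tokens=512,
            pad_token_id=self.tokenizer.eos_token_id,  # silence pad warnings
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = input_tokens.input_ids.shape[1]
    response = self.tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=True)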
 
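Review note: to exercise the updated handler locally, a smoke test along these lines should work; the no-argument constructor and the "inputs" payload key are assumptions based on the usual Hugging Face Inference Endpoints handler contract, since the full __call__ body is not shown in this diff:

    # Hypothetical smoke test; adjust env vars and payload to the real contract.
    import os

    os.environ["MODEL_DIR"] = "."  # directory with config, tokenizer, adapter weights
    os.environ["HUGGINGFACE_TOKEN"] = "hf_..."  # placeholder token

    from handler import EndpointHandler

    handler = EndpointHandler()
    print(handler({"inputs": "Does replacing a consumer unit require Part P notification?"}))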