Update handler.py
handler.py (+21 -24)
@@ -2,7 +2,7 @@ import os
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 import torch
-from unsloth import FastLanguageModel
+from unsloth import FastLanguageModel
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -16,57 +16,55 @@ class EndpointHandler:
         max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
         max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
         self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
         self.model_dir = os.getenv("MODEL_DIR", ".")
 
         print(f"MODEL_DIR: {self.model_dir}")
         print(f"Files in model directory: {os.listdir(self.model_dir)}")
 
         # --- 1. Load Config ---
-        # Load the configuration first
-        self.config = AutoConfig.from_pretrained(self.model_dir, token=self.hf_token)
+        # Load the configuration first, WITH trust_remote_code=True
+        self.config = AutoConfig.from_pretrained(self.model_dir, token=self.hf_token, trust_remote_code=True)
 
         # --- 2. Load Tokenizer ---
-        # Load the tokenizer. Handle tokenizer loading errors.
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, token=self.hf_token)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, token=self.hf_token, trust_remote_code=True)
         except Exception as e:
             print(f"Error loading tokenizer: {e}")
             raise
 
         # --- 3. Load Model ---
-        # Load the base model *and* apply the LoRA adapter. Handle model loading.
         try:
-            # Load base model
+            # Load base model, WITH trust_remote_code=True
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.config.base_model_name_or_path,
                 config=self.config,
                 torch_dtype=torch.bfloat16,
                 token=self.hf_token,
                 device_map="auto",
+                trust_remote_code=True,  # CRUCIAL
             )
 
             # Load and apply the LoRA adapter
             self.model = FastLanguageModel.get_peft_model(self.model, self.model_dir)
             FastLanguageModel.for_inference(self.model)
 
         except Exception as e:
             print(f"Error loading model: {e}")
             raise
 
-
+        # Define the prompt style
         self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
 Write a response that appropriately completes the request.
 Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 ### Instruction:
 You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
 Always be professional and precise in your responses.
+
 ### Question:
 {}
 ### Response:
 <think>{}"""
 
-
-
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Processes the input and generates a response.
@@ -84,8 +82,7 @@ Always be professional and precise in your responses.
         if torch.cuda.is_available():
             input_tokens = input_tokens.to("cuda")
 
-
-        with torch.no_grad(): # Ensure no gradient calculation
+        with torch.no_grad():
             output_tokens = self.model.generate(
                 input_ids=input_tokens.input_ids,
                 attention_mask=input_tokens.attention_mask,
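The substance of this change is threading `trust_remote_code=True` through the config, tokenizer, and model loaders. That flag tells transformers to import and execute the custom modeling/tokenizer Python files shipped inside the model repo, which is required for architectures not built into the library. A minimal sketch of a more cautious variant, using a hypothetical `TRUST_REMOTE_CODE` env var that is not part of this commit:

```python
# Sketch only: gate trust_remote_code on an env var instead of hard-coding True,
# since enabling it runs Python files from the model repo. TRUST_REMOTE_CODE is
# a hypothetical variable, not one this handler defines.
import os
from transformers import AutoConfig

model_dir = os.getenv("MODEL_DIR", ".")
trust_remote = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote)
```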
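One unchanged line worth flagging: in Unsloth's documented API, `FastLanguageModel.get_peft_model(model, ...)` attaches a *fresh* LoRA adapter (its second parameter is the rank `r`, not a directory), so passing `self.model_dir` positionally may not load previously trained adapter weights. If the intent is to load a saved adapter, a peft-based sketch would be the following, assuming `adapter_config.json` and the adapter weights live in the model directory:

```python
# Sketch: load an already-trained LoRA adapter with peft rather than creating
# a new one. Assumes adapter_config.json + adapter weights sit in adapter_dir.
from peft import PeftModel
from transformers import PreTrainedModel

def load_lora_adapter(base_model: PreTrainedModel, adapter_dir: str) -> PreTrainedModel:
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()  # inference mode
    return model
```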
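For readers following the prompt template: it has two positional `{}` slots, the first for the user's question and the second to seed the `<think>` block. The fill step presumably looks something like the sketch below; the question string is an invented example, and the empty second argument lets the model write its own chain of thought:

```python
# Sketch of filling the two-slot template inside the handler.
question = "Do loft conversions require building control approval?"  # example input
prompt = self.prompt_style.format(question, "")  # "" seeds the <think> block
input_tokens = self.tokenizer(prompt, return_tensors="pt")
```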
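Downstream of the `generate` call in the last hunk, the output ids still contain the prompt and the model's `<think>` section. The decode-and-strip step that typically follows is sketched here; the `</think>` split is an assumption about post-processing that this diff does not show:

```python
# Sketch: decode generated ids and keep only the final answer.
text = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
answer = text.split("</think>")[-1].strip()  # drop prompt + chain of thought
result = [{"generated_text": answer}]  # matches the List[Dict[str, Any]] return type
```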
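As a usage note, custom Inference Endpoints handlers expose a callable `EndpointHandler`, so the file can be smoke-tested locally before deploying. The payload shape below (`{"inputs": ...}`) follows the usual handler convention; the diff does not show the part of `__call__` that parses it, so treat the key name as an assumption:

```python
# Local smoke test sketch; assumes EndpointHandler() needs no arguments
# (configuration is env-driven) and that payloads use {"inputs": ...}.
from handler import EndpointHandler

handler = EndpointHandler()
print(handler({"inputs": "What is Part B of the UK Building Regulations about?"}))
```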