# handler.py — BuildwellAI custom inference handler
# (Hugging Face Hub, author: Buildwellai, commit 8e8bcf4 "Create handler.py", 4.33 kB)
import os
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
from unsloth import FastLanguageModel # Import Unsloth
from dotenv import load_dotenv
# Load variables from a .env file into the process environment (python-dotenv)
# before EndpointHandler reads its settings through os.getenv.
load_dotenv()
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for the BuildwellAI model.

    The model directory is expected to hold a fine-tuned model / LoRA adapter
    that ``FastLanguageModel.from_pretrained`` (Unsloth) can load.

    Settings come from environment variables (``load_dotenv`` may populate
    them from a ``.env`` file):

    * ``MAX_SEQ_LENGTH`` (default ``2048``): context length passed to Unsloth.
    * ``MAX_NEW_TOKENS`` (default ``512``): generation budget per request.
    * ``HUGGINGFACE_TOKEN``: optional token for gated/private weights.
    * ``MODEL_DIR``: model directory; falls back to ``path``, then ``"."``.
    """

    def __init__(self, path: str = ""):
        """Load the fine-tuned model and its tokenizer.

        Args:
            path: Directory supplied by the Inference Endpoints runtime. It is
                used as the model directory when ``MODEL_DIR`` is not set.

        Raises:
            Exception: Any loading error is printed and re-raised so the
                endpoint fails at start-up instead of on the first request.
        """
        # Key settings (from environment variables, with defaults).
        self.max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
        # Stored on the instance because __call__ needs it for every request.
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        self.model_dir = os.getenv("MODEL_DIR") or path or "."
        print(f"MODEL_DIR: {self.model_dir}")  # Debug print
        print(f"Files in model directory: {os.listdir(self.model_dir)}")

        try:
            # Loads the base model and the saved LoRA adapter in one step and
            # returns (model, tokenizer). Note: FastLanguageModel.get_peft_model
            # would instead attach *new*, untrained adapters (its second
            # positional argument is the LoRA rank, not a path).
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_dir,
                max_seq_length=self.max_seq_length,
                dtype=None,  # auto-detect: bfloat16 where supported, else float16
                load_in_4bit=False,  # keep full-precision weights
                token=self.hf_token,
            )
            FastLanguageModel.for_inference(self.model)  # Unsloth inference speed-up
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        # Kept as an attribute for callers that inspect the model configuration.
        self.config = self.model.config

        # Prompt template: first {} is the user's question, second {} is the
        # (empty) start of the model's reasoning after the <think> tag.
        self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
### Instruction:
You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
Always be professional and precise in your responses.
### Question:
{}
### Response:
<think>{}"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an answer for a single question.

        Args:
            data: Request payload; ``data["inputs"]`` must be the question
                string. The payload is not modified.

        Returns:
            A one-element list holding ``{"generated_text": ...}``, or
            ``{"error": ...}`` when ``inputs`` is missing or not a string.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return [{"error": "No input provided. 'inputs' key missing."}]
        if not isinstance(inputs, str):
            return [{"error": "Invalid input type. 'inputs' must be a string."}]

        prompt = self.prompt_style.format(inputs, "")
        # Send tensors to wherever the model lives rather than assuming "cuda";
        # with device maps the first parameter's device is the right target.
        encoded = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        with torch.no_grad():  # inference only, no gradient tracking
            output_tokens = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
        return [{"generated_text": self._extract_response(generated_text)}]

    @staticmethod
    def _extract_response(generated_text: str) -> str:
        """Return the text after the last ``### Response:`` marker, whitespace-stripped."""
        return generated_text.split("### Response:")[-1].strip()