# Buildwellai — handler.py
# Commit 7cd4d08 ("Update handler.py", verified), 3.76 kB.
import os
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
from peft import PeftModel
from dotenv import load_dotenv
# Populate os.environ from a local .env file, if one exists, so that the
# os.getenv lookups in EndpointHandler.__init__ (HUGGINGFACE_TOKEN, MODEL_DIR,
# MAX_SEQ_LENGTH, MAX_NEW_TOKENS) can be configured from that file.
load_dotenv()
class EndpointHandler:
    """Custom inference handler serving the BuildwellAI LoRA adapter.

    Loads the ``deepseek-ai/DeepSeek-R1-Distill-Qwen-7B`` base model in
    bfloat16, applies the PEFT adapter stored in the model directory, and
    answers each request by filling a fixed instruction/question prompt.

    Settings come from environment variables:

    * ``MAX_SEQ_LENGTH`` (default 2048): stored as ``self.max_seq_length``.
      NOTE(review): not currently applied to the tokenized prompt.
    * ``MAX_NEW_TOKENS`` (default 512): generation budget per request.
    * ``HUGGINGFACE_TOKEN``: passed as ``token=`` to every ``from_pretrained``.
    * ``MODEL_DIR``: adapter directory. It defaults to the ``path`` argument,
      or ``"."`` when ``path`` is empty.
    """

    def __init__(self, path: str = "") -> None:
        """Load the config, tokenizer, base model and LoRA adapter.

        Args:
            path: Repository directory handed over by the serving runtime.
                It is used as the adapter directory when ``MODEL_DIR`` is unset.

        Raises:
            Exception: Any error raised while loading the config, tokenizer,
                base model or adapter is propagated. Tokenizer and model
                errors are printed first.
        """
        # Kept on the instance because __call__ needs the generation budget.
        self.max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        # An explicit MODEL_DIR wins. Otherwise use the runtime-supplied path,
        # falling back to the current directory.
        self.model_dir = os.getenv("MODEL_DIR", path or ".")
        self.base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

        print(f"MODEL_DIR: {self.model_dir}")
        print(f"Files in model directory: {os.listdir(self.model_dir)}")

        # NOTE(review): recent transformers releases support Qwen2 natively,
        # so trust_remote_code=True may be unnecessary on all three loads.
        self.config = AutoConfig.from_pretrained(
            self.base_model_name, token=self.hf_token, trust_remote_code=True
        )

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name, token=self.hf_token, trust_remote_code=True
            )
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise

        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                config=self.config,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",
                trust_remote_code=True,
            )
            # Wrap the base model with the adapter weights found in model_dir.
            self.model = PeftModel.from_pretrained(base_model, self.model_dir)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        # First "{}" is the user's question; the second (after "<think>") is
        # left empty so the model starts its reasoning there.
        # NOTE(review): presumably mirrors the template the adapter was
        # fine-tuned with (including the ".." typo); confirm before editing.
        self.prompt_style = (
            "Below is an instruction that describes a task, paired with an input that provides further context.\n"
            "Write a response that appropriately completes the request.\n"
            "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n"
            "### Instruction:\n"
            "You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.\n"
            "Always be professional and precise in your responses..\n"
            "### Question:\n"
            "{}\n"
            "### Response:\n"
            "<think>{}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a response for a single request.

        Args:
            data: Request payload. The question text is read from
                ``data["inputs"]``, and the payload itself is left unmodified.

        Returns:
            A one-element list holding ``{"generated_text": ...}`` on success,
            or ``{"error": ...}`` when ``inputs`` is missing/None or not a string.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return [{"error": "No input provided. 'inputs' key missing."}]
        if not isinstance(inputs, str):
            return [{"error": "Invalid input type. 'inputs' must be a string."}]

        input_text = self.prompt_style.format(inputs, "")
        input_tokens = self.tokenizer([input_text], return_tensors="pt")
        # The model was placed with device_map="auto". Send inputs to the
        # device it reports instead of assuming "cuda".
        input_tokens = input_tokens.to(self.model.device)

        with torch.no_grad():
            output_tokens = self.model.generate(
                input_ids=input_tokens.input_ids,
                attention_mask=input_tokens.attention_mask,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.batch_decode(
            output_tokens, skip_special_tokens=True
        )[0]
        # The decoded sequence still contains the prompt. Keep only the text
        # after the last "### Response:" marker.
        # NOTE(review): if the model itself emits that marker, earlier output
        # is dropped as well.
        response = generated_text.split("### Response:")[-1].strip()
        return [{"generated_text": response}]