import os
import time
import json
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from peft import PeftModel
from dotenv import load_dotenv

load_dotenv()

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the tokenizer, base model, and fine-tuned PEFT adapter.
        """
        # Generation limits are stored on self so __call__ can use them.
        self.max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        self.model_dir = os.getenv("MODEL_DIR", ".")
        self.base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

        print(f"MODEL_DIR: {self.model_dir}")
        print(f"Files in model directory (initial): {os.listdir(self.model_dir)}")
        # Wait for the PEFT adapter config to appear in the model directory.
        config_path = os.path.join(self.model_dir, "adapter_config.json")
        max_wait_time = 60  # seconds
        wait_interval = 2  # seconds
        start_time = time.time()

        while not os.path.exists(config_path):
            print("Waiting for adapter_config.json to appear...")
            time.sleep(wait_interval)
            if time.time() - start_time > max_wait_time:
                raise FileNotFoundError(
                    f"adapter_config.json not found after {max_wait_time} seconds."
                )
        print("adapter_config.json found!")
        # Sanity-check the adapter config against the expected base model and task type.
        try:
            with open(config_path, "r") as f:
                adapter_config = json.load(f)

            if "base_model_name_or_path" not in adapter_config or \
                    "task_type" not in adapter_config:
                raise ValueError("adapter_config.json is missing required keys.")
            if adapter_config["base_model_name_or_path"] != self.base_model_name:
                raise ValueError("adapter_config.json base_model_name_or_path mismatch.")
            if adapter_config["task_type"] != "CAUSAL_LM":
                raise ValueError("adapter_config.json task_type is incorrect.")
            print("adapter_config.json contents verified.")

        except (FileNotFoundError, json.JSONDecodeError, ValueError) as e:
            raise RuntimeError(f"Error verifying adapter_config.json: {e}") from e

        print(f"Files in model directory (after wait): {os.listdir(self.model_dir)}")
        self.config = AutoConfig.from_pretrained(
            self.base_model_name, token=self.hf_token, trust_remote_code=True
        )

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name, token=self.hf_token, trust_remote_code=True
            )
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise
        # Load the base model, then attach the fine-tuned PEFT adapter from MODEL_DIR.
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                config=self.config,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",
                trust_remote_code=True,
            )
            self.model = PeftModel.from_pretrained(base_model, self.model_dir)
            self.model.eval()  # inference only; disable dropout
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.

### Instruction:
You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
Always be professional and precise in your responses.

### Question:
{}

### Response:
<think>{}"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", None)
        if inputs is None:
            return [{"error": "No input provided. 'inputs' key missing."}]
        if not isinstance(inputs, str):
            return [{"error": "Invalid input type. 'inputs' must be a string."}]

        input_text = self.prompt_style.format(inputs, "")
        input_tokens = self.tokenizer([input_text], return_tensors="pt")
        if torch.cuda.is_available():
            input_tokens = input_tokens.to("cuda")
        with torch.no_grad():
            output_tokens = self.model.generate(
                input_ids=input_tokens.input_ids,
                attention_mask=input_tokens.attention_mask,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
        response = generated_text.split("### Response:")[-1].strip()
        return [{"generated_text": response}]