import os
import time
import json
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
from peft import PeftModel
from dotenv import load_dotenv

load_dotenv()

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the model and tokenizer.
        """
        self.max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", 2048))
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 512))
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        self.model_dir = os.getenv("MODEL_DIR", ".")
        self.base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

        print(f"MODEL_DIR: {self.model_dir}")
        print(f"Files in model directory (initial): {os.listdir(self.model_dir)}")

        # --- Wait for adapter_config.json ---
        config_path = os.path.join(self.model_dir, "adapter_config.json")
        max_wait_time = 60  # Wait up to 60 seconds
        wait_interval = 2  # Check every 2 seconds
        start_time = time.time()

        while not os.path.exists(config_path):
            print("Waiting for adapter_config.json to appear...")
            time.sleep(wait_interval)
            if time.time() - start_time > max_wait_time:
                raise FileNotFoundError(
                    f"adapter_config.json not found after {max_wait_time} seconds."
                )
        print("adapter_config.json found!")

        # --- Verify adapter_config.json contents ---
        try:
            with open(config_path, "r") as f:
                adapter_config = json.load(f)
                # Check for essential keys
                if "base_model_name_or_path" not in adapter_config or \
                   "task_type" not in adapter_config:
                    raise ValueError("adapter_config.json is missing required keys.")
                if adapter_config["base_model_name_or_path"] != self.base_model_name:
                    raise ValueError("adapter_config.json base_model_name_or_path mismatch.")
                if adapter_config["task_type"] != "CAUSAL_LM":
                    raise ValueError("adapter_config.json task_type is incorrect.")
                print("adapter_config.json contents verified.")

        except (FileNotFoundError, json.JSONDecodeError, ValueError) as e:
            raise RuntimeError(f"Error verifying adapter_config.json: {e}") from e

        print(f"Files in model directory (after wait): {os.listdir(self.model_dir)}")


        # Load Config
        self.config = AutoConfig.from_pretrained(
            self.base_model_name, token=self.hf_token, trust_remote_code=True
        )

        # Load Tokenizer
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name, token=self.hf_token, trust_remote_code=True
            )
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise

        # Load Model
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                config=self.config,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",
                trust_remote_code=True,
            )
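            # Attach the fine-tuned LoRA adapter weights from MODEL_DIR on top of the base model.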
            self.model = PeftModel.from_pretrained(base_model, self.model_dir)

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        self.prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thought to ensure a logical and accurate response.

### Instruction:
You are BuildwellAI, an AI assistant specialized in UK building regulations and construction standards. You provide accurate, helpful information about building codes, construction best practices, and regulatory compliance in the UK.
Always be professional and precise in your responses.

### Question:
{}

### Response:
<think>{}"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
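        """
        Handles an inference request.

        Expects a payload of the form {"inputs": "<question string>"} and returns a
        single-element list containing {"generated_text": "<model response>"} or an
        error message.
        """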
        inputs = data.pop("inputs", None)
        if inputs is None:
            return [{"error": "No input provided. 'inputs' key missing."}]
        if not isinstance(inputs, str):
            return [{"error": "Invalid input type. 'inputs' must be a string."}]

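        # Fill the prompt template with the question; the chain-of-thought slot is
        # left empty so the model generates it after the opening <think> tag.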
        input_text = self.prompt_style.format(inputs, "")
        input_tokens = self.tokenizer(
            [input_text],
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length,  # cap the prompt at the configured maximum sequence length
        )
        if torch.cuda.is_available():
            input_tokens = input_tokens.to("cuda")

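        # Run generation without gradient tracking; decoding is greedy by default
        # since no sampling parameters are passed to generate().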
        with torch.no_grad():
            output_tokens = self.model.generate(
                input_ids=input_tokens.input_ids,
                attention_mask=input_tokens.attention_mask,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
        response = generated_text.split("### Response:")[-1].strip()
        return [{"generated_text": response}]
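

# The block below is a minimal local smoke test, not part of the Inference Endpoints
# contract (the platform only instantiates EndpointHandler and calls it with the request
# payload). It assumes MODEL_DIR points at a directory containing the trained LoRA
# adapter and that the base model can be downloaded with the configured token; the
# sample question is purely illustrative.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample_request = {"inputs": "What does Approved Document B cover?"}
    print(handler(sample_request))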