khurrameycon committed
Commit f7442cb · verified · 1 parent: ec60ae3

Update app.py

Files changed (1)
  1. app.py +15 -20
app.py CHANGED
@@ -96,10 +96,9 @@
 
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoAdapterModel
 import torch
 from huggingface_hub import snapshot_download
-from safetensors.torch import load_file
 
 class ModelInput(BaseModel):
     prompt: str
@@ -120,26 +119,22 @@ try:
         trust_remote_code=True,
         device_map="auto"
     )
-
+
     # Load tokenizer from base model
     print("Loading tokenizer...")
     tokenizer = AutoTokenizer.from_pretrained(base_model_path)
-
+
     # Download adapter weights
     print("Downloading adapter weights...")
     adapter_path_local = snapshot_download(adapter_path)
-
-    # Load the safetensors file
-    print("Loading adapter weights...")
-    state_dict = load_file(f"{adapter_path_local}/adapter_model.safetensors")
-
-    # Load state dict into model
-    model.load_state_dict(state_dict, strict=False)
-
-    # Optional: Set the model to use the adapter
-    # In case you are using adapters, you need to activate them
-    model.set_active_adapters(adapter_path) # Activating the adapter
-
+
+    # Load the adapter model
+    print("Loading adapter model...")
+    adapter_model = AutoAdapterModel.from_pretrained(adapter_path_local, from_pt=True)
+
+    # Combine the base model and adapter
+    model = model.with_adapter(adapter_model)
+
     print("Model and adapter loaded successfully!")
 
 except Exception as e:
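
Note: the added hunk imports AutoAdapterModel from transformers and calls model.with_adapter(...). AutoAdapterModel is not part of the transformers package (it ships with the separate adapters library), transformers models expose no with_adapter method, and from_pt=True is a TensorFlow/Flax conversion flag, so this code will most likely fail at import time. Since the removed code loaded adapter_model.safetensors, a PEFT-style checkpoint, the usual way to attach it is peft.PeftModel.from_pretrained. A minimal sketch, assuming the adapter repo is a PEFT/LoRA checkpoint with an adapter_config.json, and reusing base_model_path and adapter_path as app.py defines them:

    # Sketch, not part of the commit: attach a PEFT/LoRA adapter checkpoint.
    # Assumes the adapter repo contains adapter_config.json and
    # adapter_model.safetensors; base_model_path and adapter_path are the
    # same variables app.py already defines.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # Downloads the adapter from the Hub and wires its weights into the
    # matching modules of the base model.
    model = PeftModel.from_pretrained(model, adapter_path)

    # Optional: merge the adapter into the base weights for faster inference.
    # model = model.merge_and_unload()
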
@@ -153,7 +148,7 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=128):
         input_text = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-
+
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
         outputs = model.generate(
             inputs,
@@ -162,10 +157,10 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=128):
             top_p=0.9,
             do_sample=True,
         )
-
+
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response
-
+
     except Exception as e:
         raise ValueError(f"Error generating response: {e}")
 
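Note: tokenizer.decode(outputs[0], ...) decodes the full sequence, so the endpoint returns the rendered chat prompt together with the completion. A small sketch of returning only the newly generated tokens (a suggested refinement, not part of the commit; inputs and outputs are the same tensors as in generate_response above):

    # Slice off the prompt: outputs[0] is [prompt tokens][new tokens].
    new_tokens = outputs[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
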
@@ -179,7 +174,7 @@ async def generate_text(input: ModelInput):
             max_new_tokens=input.max_new_tokens
         )
         return {"generated_text": response}
-
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
180