Update app.py
app.py
CHANGED
@@ -96,10 +96,9 @@
 
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoAdapterModel
 import torch
 from huggingface_hub import snapshot_download
-from safetensors.torch import load_file
 
 class ModelInput(BaseModel):
     prompt: str
@@ -120,26 +119,22 @@ try:
         trust_remote_code=True,
         device_map="auto"
     )
-
+
     # Load tokenizer from base model
     print("Loading tokenizer...")
     tokenizer = AutoTokenizer.from_pretrained(base_model_path)
-
+
     # Download adapter weights
     print("Downloading adapter weights...")
     adapter_path_local = snapshot_download(adapter_path)
-
-    # Load the
-    print("Loading adapter
-
-
-    #
-    model.
-
-    # Optional: Set the model to use the adapter
-    # In case you are using adapters, you need to activate them
-    model.set_active_adapters(adapter_path) # Activating the adapter
-
+
+    # Load the adapter model
+    print("Loading adapter model...")
+    adapter_model = AutoAdapterModel.from_pretrained(adapter_path_local, from_pt=True)
+
+    # Combine the base model and adapter
+    model = model.with_adapter(adapter_model)
+
     print("Model and adapter loaded successfully!")
 
 except Exception as e:
@@ -153,7 +148,7 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=128):
         input_text = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-
+
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
         outputs = model.generate(
             inputs,
@@ -162,10 +157,10 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=128):
             top_p=0.9,
             do_sample=True,
         )
-
+
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response
-
+
     except Exception as e:
         raise ValueError(f"Error generating response: {e}")
 
@@ -179,7 +174,7 @@ async def generate_text(input: ModelInput):
             max_new_tokens=input.max_new_tokens
        )
         return {"generated_text": response}
-
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
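A note on the added lines: core transformers does not export AutoAdapterModel or a with_adapter() method. AutoAdapterModel comes from the AdapterHub ecosystem (the old adapter-transformers fork, now the standalone adapters package), which is also where the set_active_adapters() call in the removed lines comes from, so this import will fail against a stock transformers install. If the adapter was trained with PEFT (LoRA/QLoRA), the usual load-and-attach flow is a PeftModel wrapper. A minimal sketch under that assumption follows; the repo IDs are hypothetical placeholders for the base_model_path and adapter_path values defined earlier in app.py.

# Minimal sketch, assuming a PEFT (LoRA) adapter rather than an
# AdapterHub adapter. Repo IDs are placeholders, not the app's real ones.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = "org/base-model"   # placeholder
adapter_path = "user/adapter-repo"   # placeholder

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# PeftModel.from_pretrained resolves the adapter repo on the Hub itself and
# wraps the base model so that generate() runs with the adapter applied.
model = PeftModel.from_pretrained(base_model, adapter_path)

With this route the explicit snapshot_download() step becomes unnecessary, since PeftModel.from_pretrained accepts a Hub repo ID directly and handles the download.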