amiguel committed
Commit e65b516 · verified
1 Parent(s): fb280d1

Update app.py

Files changed (1)
  1. app.py +48 -33
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from huggingface_hub import login
 from threading import Thread
 import PyPDF2
@@ -24,8 +24,12 @@ st.set_page_config(
 )

 # Model names
-BASE_MODEL_NAME ="HuggingFaceTB/SmolLM2-360M" #"HuggingFaceTB/SmolLM2-1.7B-Instruct" #"google/flan-t5-base" # Base FLAN-T5 model
-PEFT_ADAPTER_NAME ="amiguel/enterpriseFTmodel" #"amiguel/cerebrasFTdeepseek" #"amiguel/classFinetuned_deepseek" # PEFT adapter
+BASE_MODEL_NAME = "HuggingFaceTB/SmolLM2-360M"
+MODEL_OPTIONS = {
+    "Full Fine-Tuned": "amiguel/SmolLM2-360M-concise-reasoning",
+    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
+    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"  # Hypothetical, adjust if needed
+}

 # Title with rocket emojis
 st.title("🚀 WizNerd Insp 🚀")
@@ -40,6 +44,10 @@ with st.sidebar:
     hf_token = st.text_input("Hugging Face Token", type="password",
                              help="Get your token from https://huggingface.co/settings/tokens")

+    st.header("Model Selection 🤖")
+    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
+    selected_model = MODEL_OPTIONS[model_type]
+
     st.header("Upload Documents 📂")
     uploaded_file = st.file_uploader(
         "Choose a PDF or XLSX file",
@@ -70,7 +78,7 @@ def process_file(uploaded_file):

 # Model loading function
 @st.cache_resource
-def load_model(hf_token):
+def load_model(hf_token, model_type, selected_model):
     try:
         if not hf_token:
             st.error("🔐 Authentication required! Please provide a Hugging Face token.")
@@ -78,37 +86,42 @@ def load_model(hf_token):

         login(token=hf_token)

-        # Load base FLAN-T5 model
-        peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
-            BASE_MODEL_NAME,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            token=hf_token
-        )
-
-        # Load PEFT adapter and merge with base model
-        peft_model = PeftModel.from_pretrained(
-            peft_model_base,
-            PEFT_ADAPTER_NAME,
-            torch_dtype=torch.bfloat16,
-            is_trainable=False, # Set to False for inference
-            token=hf_token
-        )
-
         # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
-            BASE_MODEL_NAME,
-            token=hf_token
-        )
+        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
+
+        # Load model based on type
+        if model_type == "Full Fine-Tuned":
+            # Load full fine-tuned model directly
+            model = AutoModelForCausalLM.from_pretrained(
+                selected_model,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                token=hf_token
+            )
+        else:
+            # Load base model and apply PEFT adapter
+            base_model = AutoModelForCausalLM.from_pretrained(
+                BASE_MODEL_NAME,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                token=hf_token
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                selected_model,
+                torch_dtype=torch.bfloat16,
+                is_trainable=False,  # Inference mode
+                token=hf_token
+            )

-        return peft_model, tokenizer
+        return model, tokenizer

     except Exception as e:
         st.error(f"🤖 Model loading failed: {str(e)}")
         return None

 # Generation function with KV caching
-def generate_with_kv_cache(prompt, file_context, use_cache=True):
+def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
     full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"

     streamer = TextIteratorStreamer(
@@ -120,7 +133,8 @@ def generate_with_kv_cache(prompt, file_context, use_cache=True):
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

     generation_kwargs = {
-        **inputs,
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
         "max_new_tokens": 1024,
         "temperature": 0.7,
         "top_p": 0.9,
@@ -149,14 +163,15 @@ if prompt := st.chat_input("Ask your inspection question..."):
         st.error("🔑 Authentication required!")
         st.stop()

-    # Load model if not already loaded
-    if "model" not in st.session_state:
-        model_data = load_model(hf_token)
+    # Load model if not already loaded or if model type changed
+    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
+        model_data = load_model(hf_token, model_type, selected_model)
         if model_data is None:
             st.error("Failed to load model. Please check your token and try again.")
             st.stop()

         st.session_state.model, st.session_state.tokenizer = model_data
+        st.session_state.model_type = model_type

     model = st.session_state.model
     tokenizer = st.session_state.tokenizer
@@ -174,7 +189,7 @@ if prompt := st.chat_input("Ask your inspection question..."):
         try:
             with st.chat_message("assistant", avatar=BOT_AVATAR):
                 start_time = time.time()
-                streamer = generate_with_kv_cache(prompt, file_context, use_cache=True)
+                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)

                 response_container = st.empty()
                 full_response = ""
@@ -209,4 +224,4 @@ if prompt := st.chat_input("Ask your inspection question..."):
         except Exception as e:
             st.error(f"⚡ Generation error: {str(e)}")
         else:
-            st.error("🤖 Model not loaded!")
+            st.error("🤖 Model not loaded!")
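
Note: the hunks above assemble generation_kwargs for the causal-LM model, but the lines that launch generation in a background thread and consume the streamer fall outside the changed ranges. The following is a minimal standalone sketch of the TextIteratorStreamer pattern generate_with_kv_cache relies on; the model name mirrors BASE_MODEL_NAME from the diff, while the prompt placeholders, the do_sample flag, and the final join are illustrative assumptions rather than the app's exact code.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Standalone sketch (not the app's code); model name taken from BASE_MODEL_NAME in the diff
model_name = "HuggingFaceTB/SmolLM2-360M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prompt format mirrors full_prompt in generate_with_kv_cache; placeholders are illustrative
prompt = "Analyze this context:\n<file text>\n\nQuestion: <question>\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# The streamer yields decoded text chunks as generate() produces tokens
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = {
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
    "max_new_tokens": 1024,
    "do_sample": True,   # assumption: sampling enabled so temperature/top_p take effect
    "temperature": 0.7,
    "top_p": 0.9,
    "use_cache": True,   # KV caching, as in generate_with_kv_cache
    "streamer": streamer,
}

# generate() blocks until completion, so it runs in a background thread while the
# caller iterates the streamer (the app appends each chunk to an st.empty() placeholder)
Thread(target=model.generate, kwargs=generation_kwargs).start()
full_response = "".join(chunk for chunk in streamer)
print(full_response)

Without the background thread, model.generate() would only return once the full answer is ready, so nothing could be streamed to the chat container incrementally.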