ehagey committed
Commit b5ac215 · verified · 1 Parent(s): a48ccc1

Update app.py

Files changed (1)
  1. app.py +151 -246
app.py CHANGED
@@ -16,30 +16,12 @@ from anthropic import Anthropic
  import google.generativeai as genai
  import hmac
  import hashlib
- from uuid import uuid4
- from datetime import datetime
- from huggingface_hub import CommitScheduler, Repository
- from pathlib import Path

- load_dotenv()
-
- st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")

- WRITE_LOCK = threading.Lock()
- DATA_DIR = Path("data")
- DATA_DIR.mkdir(exist_ok=True)
- RESULTS_FILE = DATA_DIR / "results.csv"


- scheduler = CommitScheduler(
-     repo_id=os.getenv("HF_REPO_ID"),
-     repo_type="dataset",
-     folder_path=DATA_DIR,
-     path_in_repo="data",
-     every=10,
-     token=os.getenv("HF_TOKEN")
- )

  def initialize_session_state():
      if 'api_configured' not in st.session_state:
          st.session_state.api_configured = False
@@ -49,18 +31,10 @@ def initialize_session_state():
          st.session_state.openai_client = None
      if 'anthropic_client' not in st.session_state:
          st.session_state.anthropic_client = None
-     if 'all_results' not in st.session_state:
-         st.session_state.all_results = {}
-     if 'detailed_model' not in st.session_state:
-         st.session_state.detailed_model = None
-     if 'detailed_dataset' not in st.session_state:
-         st.session_state.detailed_dataset = None
-     if 'last_evaluated_dataset' not in st.session_state:
-         st.session_state.last_evaluated_dataset = None
-
- initialize_session_state()

  def setup_api_clients():
      with st.sidebar:
          st.title("API Configuration")

@@ -71,29 +45,22 @@ def setup_api_clients():
          password = st.text_input("Password", type="password")

          if st.button("Verify Credentials"):
-             stored_username = os.getenv("STREAMLIT_USERNAME", "")
-             stored_password = os.getenv("STREAMLIT_PASSWORD", "")
-
-             if (hmac.compare_digest(username, stored_username) and
-                 hmac.compare_digest(password, stored_password)):
-                 try:
-                     st.session_state.togetherai_client = OpenAI(
-                         api_key=os.getenv('TOGETHERAI_API_KEY'),
-                         base_url="https://api.together.xyz/v1"
-                     )
-                     st.session_state.openai_client = OpenAI(
-                         api_key=os.getenv('OPENAI_API_KEY')
-                     )
-                     st.session_state.anthropic_client = Anthropic(
-                         api_key=os.getenv('ANTHROPIC_API_KEY')
-                     )
-                     genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
-
-                     st.session_state.api_configured = True
-                     st.success("Successfully configured the API clients with stored keys!")
-                 except Exception as e:
-                     st.error(f"Error initializing API clients: {str(e)}")
-                     st.session_state.api_configured = False
              else:
                  st.error("Invalid credentials. Please try again or use your own API keys.")
                  st.session_state.api_configured = False
@@ -124,11 +91,6 @@ def setup_api_clients():
                  st.error(f"Error initializing API clients: {str(e)}")
                  st.session_state.api_configured = False

- setup_api_clients()
-
-
- scheduler.start()
-
  MAX_CONCURRENT_CALLS = 5
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

@@ -154,7 +116,7 @@ def load_dataset_by_name(dataset_name, split="train"):
          }
          questions.append(question_dict)

-     st.write(f"Loaded {len(questions)} single-select questions from `{dataset_name}`")
      return questions

  @retry(
@@ -162,6 +124,7 @@ def load_dataset_by_name(dataset_name, split="train"):
      stop=stop_after_attempt(5),
      retry=retry_if_exception_type(Exception)
  )
  def get_model_response(question, options, prompt_template, model_name, clients):
      with semaphore:
          try:
@@ -203,7 +166,6 @@ def get_model_response(question, options, prompt_template, model_name, clients):
                  )
                  response_text = chat_session.send_message(prompt).text

-             # Extract JSON from response
              json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
              if not json_match:
                  return f"Error: Invalid response format", response_text
@@ -219,14 +181,13 @@ def get_model_response(question, options, prompt_template, model_name, clients):
          except Exception as e:
              return f"Error: {str(e)}", str(e)

-
  def evaluate_response(model_response, correct_answer):
      if model_response.startswith("Error:"):
          return False
      is_correct = model_response.lower().strip() == correct_answer.lower().strip()
      return is_correct

- def process_single_evaluation(question, prompt_template, model_name, clients, last_evaluated_dataset):
      answer, response_text = get_model_response(
          question['question'],
          question['options'],
@@ -235,57 +196,29 @@ def process_single_evaluation(question, prompt_template, model_name, clients, last_evaluated_dataset):
          clients
      )
      is_correct = evaluate_response(answer, question['correct_answer'])
-     result = {
-         'dataset': last_evaluated_dataset,
-         'model': model_name,
          'question': question['question'],
          'correct_answer': question['correct_answer'],
          'subject': question['subject_name'],
-         'options': ' | '.join(question['options']),
-         'model_response': answer,
          'is_correct': is_correct,
          'explanation': question['explanation'],
-         'timestamp': datetime.utcnow().isoformat()
      }
-
-     with WRITE_LOCK:
-         if RESULTS_FILE.exists():
-             existing_df = pd.read_csv(RESULTS_FILE)
-             updated_df = existing_df.append(result, ignore_index=True)
-         else:
-             updated_df = pd.DataFrame([result])
-
-         updated_df.to_csv(RESULTS_FILE, index=False)
-
-     return result

- def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients, last_evaluated_dataset):
      results = []
      total_iterations = len(models_to_evaluate) * len(questions)
      current_iteration = 0

-     if RESULTS_FILE.exists():
-         existing_df = pd.read_csv(RESULTS_FILE)
-         completed = set(zip(existing_df['model'], existing_df['question']))
-     else:
-         completed = set()
-
      with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
          future_to_params = {}
          for model_name in models_to_evaluate:
              for question in questions:
-                 if (model_name, question['question']) in completed:
-                     current_iteration += 1
-                     progress_callback(current_iteration, total_iterations)
-                     continue # Skip already completed evaluations
-                 future = executor.submit(
-                     process_single_evaluation,
-                     question,
-                     prompt_template,
-                     model_name,
-                     clients,
-                     last_evaluated_dataset
-                 )
                  future_to_params[future] = (model_name, question)

          for future in as_completed(future_to_params):
@@ -295,42 +228,25 @@ def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients, last_evaluated_dataset):
              progress_callback(current_iteration, total_iterations)

      return results
-
-
  def main():
      if 'all_results' not in st.session_state:
          st.session_state.all_results = {}
          st.session_state.last_evaluated_dataset = None
-     if RESULTS_FILE.exists():
-         existing_df = pd.read_csv(RESULTS_FILE)
-         all_results = {}
-         for _, row in existing_df.iterrows():
-             model = row['model']
-             result = row.to_dict()
-             if model not in all_results:
-                 all_results[model] = []
-             all_results[model].append(result)
-         st.session_state.all_results = all_results
-         st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
-         st.info(f"Loaded existing results from `{RESULTS_FILE}`.")
-     else:
-         st.session_state.all_results = {}
-         st.session_state.last_evaluated_dataset = None
-         st.info(f"No existing results found. Ready to start fresh.")
-
-     with st.sidebar:
-         if st.button("Reset Results"):
-             if RESULTS_FILE.exists():
-                 try:
-                     RESULTS_FILE.unlink()
-                     st.session_state.all_results = {}
-                     st.session_state.last_evaluated_dataset = None
-                     st.success("Results have been reset.")
-                 except Exception as e:
-                     st.error(f"Error deleting file: {str(e)}")
-             else:
-                 st.info("No results to reset.")
-
      col1, col2 = st.columns(2)
      with col1:
          selected_dataset = st.selectbox(
@@ -339,46 +255,41 @@ def main():
              help="Choose the dataset to evaluate on"
          )
      with col2:
-         selected_models = st.multiselect(
              "Select Model(s)",
              options=list(MODELS.keys()),
              default=[list(MODELS.keys())[0]],
              help="Choose one or more models to evaluate."
          )

-     models_to_evaluate = selected_models

      default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
  Question: {question}
-
  Options:
  {options}
-
  ## Output Format:
  Please provide your answer in JSON format that contains an "answer" field.
  You may include any additional fields in your JSON response that you find relevant, such as:
  - "choice reasoning": your detailed reasoning
  - "elimination reasoning": why you ruled out other options
-
  Example response format:
  {
  "answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
  "choice reasoning": "your detailed reasoning here",
  "elimination reasoning": "why you ruled out other options"
  }
-
  Important:
  - Only the "answer" field will be used for evaluation
  - Ensure your response is in valid JSON format'''

-     # Customize Prompt Template
      col1, col2 = st.columns([2, 1])
      with col1:
          prompt_template = st.text_area(
              "Customize Prompt Template",
              default_prompt,
              height=400,
-             help="Edit the prompt template before starting the evaluation."
          )

      with col2:
@@ -388,90 +299,67 @@ Important:
  - `{options}`: The multiple choice options
  """)

-     # Load Dataset
-     if st.session_state.api_configured:
-         with st.spinner("Loading dataset..."):
-             questions = load_dataset_by_name(selected_dataset)
-     else:
-         st.warning("Please configure the API keys in the sidebar to load datasets and proceed.")
-         questions = []
-
-     # Filter by Subject
-     if questions:
-         subjects = sorted(list(set(q['subject_name'] for q in questions)))
-         selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

-         if selected_subject != "All":
-             questions = [q for q in questions if q['subject_name'] == selected_subject]
-
-         # Number of Questions to Evaluate
-         num_questions = st.number_input(
-             "Number of questions to evaluate",
-             min_value=1,
-             max_value=len(questions),
-             value=min(10, len(questions)),
-             step=1
-         )

-         # Start Evaluation Button
-         if st.button("Start Evaluation"):
-             if not models_to_evaluate:
-                 st.error("Please select at least one model to evaluate.")
-             else:
-                 with st.spinner("Starting evaluation..."):
-                     selected_questions = questions[:num_questions]
-
-                     clients = {
-                         "togetherai": st.session_state["togetherai_client"],
-                         "openai": st.session_state["openai_client"],
-                         "anthropic": st.session_state["anthropic_client"]
-                     }
-
-                     last_evaluated_dataset = st.session_state.last_evaluated_dataset if st.session_state.last_evaluated_dataset else selected_dataset

-                     progress_container = st.container()
-                     progress_bar = progress_container.progress(0)
-                     status_text = progress_container.empty()
-
-                     def update_progress(current, total):
-                         progress = current / total
-                         progress_bar.progress(progress)
-                         status_text.text(f"Progress: {current}/{total} evaluations completed")
-
-                     results = process_evaluations_concurrently(
-                         selected_questions,
-                         prompt_template,
-                         models_to_evaluate,
-                         update_progress,
-                         clients,
-                         last_evaluated_dataset
-                     )
-
-                     # Update Session State with New Results
-                     all_results = st.session_state.all_results.copy()
-                     for result in results:
-                         model = result.pop('model')
-                         if model not in all_results:
-                             all_results[model] = []
-                         all_results[model].append(result)
-
-                     st.session_state.all_results = all_results
-                     st.session_state.last_evaluated_dataset = selected_dataset
-
-                     # Set Default Detailed Model and Dataset if Not Set
-                     if st.session_state.detailed_model is None and all_results:
-                         st.session_state.detailed_model = list(all_results.keys())[0]
-                     if st.session_state.detailed_dataset is None:
-                         st.session_state.detailed_dataset = selected_dataset
-
-                     st.success("Evaluation completed!")
-                     st.experimental_rerun()
-
-     # Display Evaluation Results
      if st.session_state.all_results:
          st.subheader("Evaluation Results")
          model_metrics = {}
-
          for model_name, results in st.session_state.all_results.items():
              df = pd.DataFrame(results)
              metrics = {
@@ -479,28 +367,29 @@ Important:
              }
              model_metrics[model_name] = metrics

-         metrics_df = pd.DataFrame(model_metrics).T.reset_index().rename(columns={'index': 'Model'})

          st.subheader("Model Performance Comparison")
          accuracy_chart = alt.Chart(
-             metrics_df
          ).mark_bar().encode(
-             x=alt.X('Model:N', title=None),
-             y=alt.Y('Accuracy:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
-             color=alt.Color('Model:N', scale=alt.Scale(scheme='blues')),
-             tooltip=['Model:N', 'Accuracy:Q']
          ).properties(
              height=300,
              title={
                  "text": "Model Accuracy",
-                 "anchor": "middle",
-                 "fontSize": 20
              }
-         ).interactive()

          st.altair_chart(accuracy_chart, use_container_width=True)

-     # Display Detailed Results
      if st.session_state.all_results:
          st.subheader("Detailed Results")

@@ -524,12 +413,12 @@ Important:
          with col2:
              selected_dataset_details = st.selectbox(
                  "Select dataset",
-                 options=[st.session_state.last_evaluated_dataset] if st.session_state.last_evaluated_dataset else [],
                  key="dataset_select",
                  on_change=update_dataset
              )

-         if selected_model_details and selected_model_details in st.session_state.all_results:
              results = st.session_state.all_results[selected_model_details]
              df = pd.DataFrame(results)
              accuracy = df['is_correct'].mean()
@@ -540,16 +429,16 @@ Important:
                  with st.expander(f"Question {idx + 1} - {result['subject']}"):
                      st.write("**Question:**", result['question'])
                      st.write("**Options:**")
-                     for i, opt in enumerate(result['options'].split(' | ')):
                          st.write(f"{chr(65+i)}. {opt}")

                      col1, col2 = st.columns(2)
                      with col1:
-                         st.write("**Model Response:**")
-                         st.code(result.get('model_response', "N/A"))
                      with col2:
-                         st.write("**Explanation:**")
-                         st.code(result.get('explanation', "N/A"))

                      col1, col2 = st.columns(2)
                      with col1:
@@ -561,23 +450,39 @@ Important:
                          else:
                              st.error("Incorrect")

-                     st.write("**Timestamp:**", result['timestamp'])
          else:
              st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

          st.markdown("---")
-         st.subheader("Download Results")
-         if RESULTS_FILE.exists():
-             csv_data = RESULTS_FILE.read_text(encoding='utf-8')
-             st.download_button(
-                 label="Download All Results as CSV",
-                 data=csv_data,
-                 file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
-                 mime="text/csv",
-                 key="download_all_results"
-             )
-         else:
-             st.info("No data available to download.")

  if __name__ == "__main__":
-     main()

  import google.generativeai as genai
  import hmac
  import hashlib




+ load_dotenv()
  def initialize_session_state():
      if 'api_configured' not in st.session_state:
          st.session_state.api_configured = False

          st.session_state.openai_client = None
      if 'anthropic_client' not in st.session_state:
          st.session_state.anthropic_client = None

  def setup_api_clients():
+     initialize_session_state()
+
      with st.sidebar:
          st.title("API Configuration")

          password = st.text_input("Password", type="password")

          if st.button("Verify Credentials"):
+             if (hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", "")) and
+                 hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))):
+                 st.session_state.togetherai_client = OpenAI(
+                     api_key=os.getenv('TOGETHERAI_API_KEY'),
+                     base_url="https://api.together.xyz/v1"
+                 )
+                 st.session_state.openai_client = OpenAI(
+                     api_key=os.getenv('OPENAI_API_KEY')
+                 )
+                 st.session_state.anthropic_client = Anthropic(
+                     api_key=os.getenv('ANTHROPIC_API_KEY')
+                 )
+                 genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+                 st.session_state.api_configured = True
+                 st.success("Successfully configured the API clients with stored keys!")
              else:
                  st.error("Invalid credentials. Please try again or use your own API keys.")
                  st.session_state.api_configured = False

                  st.error(f"Error initializing API clients: {str(e)}")
                  st.session_state.api_configured = False

  MAX_CONCURRENT_CALLS = 5
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

          }
          questions.append(question_dict)

+     st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
      return questions

  @retry(
      stop=stop_after_attempt(5),
      retry=retry_if_exception_type(Exception)
  )
+
  def get_model_response(question, options, prompt_template, model_name, clients):
      with semaphore:
          try:

                  )
                  response_text = chat_session.send_message(prompt).text

              json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
              if not json_match:
                  return f"Error: Invalid response format", response_text

          except Exception as e:
              return f"Error: {str(e)}", str(e)

  def evaluate_response(model_response, correct_answer):
      if model_response.startswith("Error:"):
          return False
      is_correct = model_response.lower().strip() == correct_answer.lower().strip()
      return is_correct

+ def process_single_evaluation(question, prompt_template, model_name, clients):
      answer, response_text = get_model_response(
          question['question'],
          question['options'],

          clients
      )
      is_correct = evaluate_response(answer, question['correct_answer'])
+     return {
          'question': question['question'],
+         'options': question['options'],
+         'model_response': answer,
+         'raw_llm_response': response_text,
+         'prompt_sent': prompt_template,
          'correct_answer': question['correct_answer'],
          'subject': question['subject_name'],
          'is_correct': is_correct,
          'explanation': question['explanation'],
+         'model_name': model_name
      }

+ def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients):
      results = []
      total_iterations = len(models_to_evaluate) * len(questions)
      current_iteration = 0

      with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
          future_to_params = {}
          for model_name in models_to_evaluate:
              for question in questions:
+                 future = executor.submit(process_single_evaluation, question, prompt_template, model_name, clients)
                  future_to_params[future] = (model_name, question)

          for future in as_completed(future_to_params):

              progress_callback(current_iteration, total_iterations)

      return results
+
  def main():
+     st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
+
+     initialize_session_state()
+     setup_api_clients()
+
+     if not st.session_state.api_configured:
+         st.warning("Please configure API keys in the sidebar to proceed")
+         st.stop()
+
      if 'all_results' not in st.session_state:
          st.session_state.all_results = {}
+     if 'detailed_model' not in st.session_state:
+         st.session_state.detailed_model = None
+     if 'detailed_dataset' not in st.session_state:
+         st.session_state.detailed_dataset = None
+     if 'last_evaluated_dataset' not in st.session_state:
          st.session_state.last_evaluated_dataset = None

      col1, col2 = st.columns(2)
      with col1:
          selected_dataset = st.selectbox(

              help="Choose the dataset to evaluate on"
          )
      with col2:
+         selected_model = st.multiselect(
              "Select Model(s)",
              options=list(MODELS.keys()),
              default=[list(MODELS.keys())[0]],
              help="Choose one or more models to evaluate."
          )

+     models_to_evaluate = selected_model

      default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
  Question: {question}
  Options:
  {options}
  ## Output Format:
  Please provide your answer in JSON format that contains an "answer" field.
  You may include any additional fields in your JSON response that you find relevant, such as:
  - "choice reasoning": your detailed reasoning
  - "elimination reasoning": why you ruled out other options
  Example response format:
  {
  "answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
  "choice reasoning": "your detailed reasoning here",
  "elimination reasoning": "why you ruled out other options"
  }
  Important:
  - Only the "answer" field will be used for evaluation
  - Ensure your response is in valid JSON format'''

      col1, col2 = st.columns([2, 1])
      with col1:
          prompt_template = st.text_area(
              "Customize Prompt Template",
              default_prompt,
              height=400,
+             help="The below prompt is editable. Please feel free to edit it before your run."
          )

      with col2:

  - `{options}`: The multiple choice options
  """)

+     with st.spinner("Loading dataset..."):
+         questions = load_dataset_by_name(selected_dataset)
+     subjects = sorted(list(set(q['subject_name'] for q in questions)))
+     selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
+
+     if selected_subject != "All":
+         questions = [q for q in questions if q['subject_name'] == selected_subject]
+
+     num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
+
+     if st.button("Start Evaluation"):
+         with st.spinner("Starting evaluation..."):
+             selected_questions = questions[:num_questions]
+
+             # Create a clients dictionary
+             clients = {
+                 "togetherai": st.session_state["togetherai_client"],
+                 "openai": st.session_state["openai_client"],
+                 "anthropic": st.session_state["anthropic_client"]
+             }
+
+             progress_container = st.container()
+             progress_bar = progress_container.progress(0)
+             status_text = progress_container.empty()
+
+             def update_progress(current, total):
+                 progress = current / total
+                 progress_bar.progress(progress)
+                 status_text.text(f"Progress: {current}/{total} evaluations completed")
+
+             results = process_evaluations_concurrently(
+                 selected_questions,
+                 prompt_template,
+                 models_to_evaluate,
+                 update_progress,
+                 clients
+             )

+             all_results = {}
+             for result in results:
+                 model = result.pop('model_name')
+                 if model not in all_results:
+                     all_results[model] = []
+                 all_results[model].append(result)
+
+             st.session_state.all_results = all_results
+             st.session_state.last_evaluated_dataset = selected_dataset


+             if st.session_state.detailed_model is None and all_results:
+                 st.session_state.detailed_model = list(all_results.keys())[0]
+             if st.session_state.detailed_dataset is None:
+                 st.session_state.detailed_dataset = selected_dataset
+
+             st.success("Evaluation completed!")
+             st.rerun()
+
      if st.session_state.all_results:
          st.subheader("Evaluation Results")
+
          model_metrics = {}
          for model_name, results in st.session_state.all_results.items():
              df = pd.DataFrame(results)
              metrics = {

              }
              model_metrics[model_name] = metrics

+         metrics_df = pd.DataFrame(model_metrics).T

          st.subheader("Model Performance Comparison")
+
          accuracy_chart = alt.Chart(
+             metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
          ).mark_bar().encode(
+             x=alt.X('index:N', title=None, axis=None),
+             y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
+             color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
+             tooltip=['index:N', 'value:Q']
          ).properties(
              height=300,
              title={
                  "text": "Model Accuracy",
+                 "baseline": "bottom",
+                 "orient": "bottom",
+                 "dy": 20
              }
+         )

          st.altair_chart(accuracy_chart, use_container_width=True)

      if st.session_state.all_results:
          st.subheader("Detailed Results")

          with col2:
              selected_dataset_details = st.selectbox(
                  "Select dataset",
+                 options=[st.session_state.last_evaluated_dataset],
                  key="dataset_select",
                  on_change=update_dataset
              )

+         if selected_model_details in st.session_state.all_results:
              results = st.session_state.all_results[selected_model_details]
              df = pd.DataFrame(results)
              accuracy = df['is_correct'].mean()

                  with st.expander(f"Question {idx + 1} - {result['subject']}"):
                      st.write("**Question:**", result['question'])
                      st.write("**Options:**")
+                     for i, opt in enumerate(result['options']):
                          st.write(f"{chr(65+i)}. {opt}")

                      col1, col2 = st.columns(2)
                      with col1:
+                         st.write("**Prompt Used:**")
+                         st.code(result['prompt_sent'])
                      with col2:
+                         st.write("**Raw Response:**")
+                         st.code(result['raw_llm_response'])

                      col1, col2 = st.columns(2)
                      with col1:

                          else:
                              st.error("Incorrect")

+                     st.write("**Explanation:**", result['explanation'])
          else:
              st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

          st.markdown("---")
+         all_data = []
+
+         for model_name, results in st.session_state.all_results.items():
+             for question_idx, result in enumerate(results):
+                 row = {
+                     'dataset': st.session_state.last_evaluated_dataset,
+                     'model': model_name,
+                     'question': result['question'],
+                     'correct_answer': result['correct_answer'],
+                     'subject': result['subject'],
+                     'options': ' | '.join(result['options']),
+                     'model_response': result['model_response'],
+                     'is_correct': result['is_correct'],
+                     'explanation': result['explanation']
+                 }
+                 all_data.append(row)
+
+         complete_df = pd.DataFrame(all_data)
+
+         csv = complete_df.to_csv(index=False)
+
+         st.download_button(
+             label="Download All Results as CSV",
+             data=csv,
+             file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
+             mime="text/csv",
+             key="download_all_results"
+         )

  if __name__ == "__main__":
+     main()