ehagey committed
Commit ceec99c · verified · 1 Parent(s): b7aa921

Update app.py

Files changed (1)
1. app.py  +149 -123
app.py CHANGED
@@ -18,22 +18,27 @@ import hmac
 import hashlib
 from uuid import uuid4
 from datetime import datetime
+from huggingface_hub import CommitScheduler, Repository
+from pathlib import Path
 
 load_dotenv()
 
 st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
 
-
 WRITE_LOCK = threading.Lock()
-DATA_DIR = "data"
-RESULTS_FILE = os.path.join(DATA_DIR, "results.csv")
-
-if not os.path.exists(DATA_DIR):
-    os.makedirs(DATA_DIR)
-    st.success(f"Created `{DATA_DIR}` directory.")
-else:
-    st.info(f"`{DATA_DIR}` directory already exists.")
-
+DATA_DIR = Path("data")
+DATA_DIR.mkdir(exist_ok=True)
+RESULTS_FILE = DATA_DIR / "results.csv"
+
+
+scheduler = CommitScheduler(
+    repo_id=os.getenv("HF_REPO_ID"),
+    repo_type="dataset",
+    folder_path=DATA_DIR,
+    path_in_repo="data",
+    every=10,
+    token=os.getenv("HF_TOKEN")
+)
 
 def initialize_session_state():
     if 'api_configured' not in st.session_state:
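
The CommitScheduler is the headline change here: instead of only writing results.csv into a local data/ directory (which is ephemeral on Spaces), the folder is now committed to a Hugging Face dataset repo on a background thread roughly every 10 minutes, using the HF_REPO_ID and HF_TOKEN environment variables. A minimal sketch of the same pattern, with an illustrative append_row helper that is not part of app.py:

```python
# Sketch of the CommitScheduler pattern introduced above; append_row is illustrative.
import os
from pathlib import Path

import pandas as pd
from huggingface_hub import CommitScheduler

DATA_DIR = Path("data")
DATA_DIR.mkdir(exist_ok=True)
RESULTS_FILE = DATA_DIR / "results.csv"

scheduler = CommitScheduler(
    repo_id=os.getenv("HF_REPO_ID"),   # e.g. "username/llm-healthcare-results" (placeholder)
    repo_type="dataset",
    folder_path=DATA_DIR,              # everything in this folder gets committed
    path_in_repo="data",
    every=10,                          # minutes between background commits
    token=os.getenv("HF_TOKEN"),
)

def append_row(row: dict) -> None:
    # Writing under scheduler.lock avoids committing a half-written file.
    with scheduler.lock:
        header = not RESULTS_FILE.exists()
        pd.DataFrame([row]).to_csv(RESULTS_FILE, mode="a", header=header, index=False)
```

In current huggingface_hub releases the scheduler starts its background loop as soon as it is constructed, and its lock attribute is the documented way to coordinate writes with the periodic upload; the Repository import in the diff is not needed for this pattern.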
@@ -54,8 +59,8 @@ def initialize_session_state():
         st.session_state.last_evaluated_dataset = None
 
 initialize_session_state()
-def setup_api_clients():
 
+def setup_api_clients():
     with st.sidebar:
         st.title("API Configuration")
 
@@ -66,8 +71,11 @@ def setup_api_clients():
         password = st.text_input("Password", type="password")
 
         if st.button("Verify Credentials"):
-            if (hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", "")) and
-                hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))):
+            stored_username = os.getenv("STREAMLIT_USERNAME", "")
+            stored_password = os.getenv("STREAMLIT_PASSWORD", "")
+
+            if (hmac.compare_digest(username, stored_username) and
+                hmac.compare_digest(password, stored_password)):
                 try:
                     st.session_state.togetherai_client = OpenAI(
                         api_key=os.getenv('TOGETHERAI_API_KEY'),
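
The credential check still relies on hmac.compare_digest, which compares strings in constant time; the change just reads the stored values into locals first. Pulled out of Streamlit, the check amounts to this (check_credentials is an illustrative name, not a function in app.py):

```python
import hmac
import os

def check_credentials(username: str, password: str) -> bool:
    # hmac.compare_digest compares in constant time, so timing does not reveal
    # how many leading characters of the secret matched.
    stored_username = os.getenv("STREAMLIT_USERNAME", "")
    stored_password = os.getenv("STREAMLIT_PASSWORD", "")
    return (hmac.compare_digest(username, stored_username) and
            hmac.compare_digest(password, stored_password))
```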
@@ -79,7 +87,7 @@ def setup_api_clients():
                     st.session_state.anthropic_client = Anthropic(
                         api_key=os.getenv('ANTHROPIC_API_KEY')
                     )
-                    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+                    genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
 
                     st.session_state.api_configured = True
                     st.success("Successfully configured the API clients with stored keys!")
@@ -117,6 +125,10 @@ def setup_api_clients():
             st.session_state.api_configured = False
 
 setup_api_clients()
+
+
+scheduler.start()
+
 
 MAX_CONCURRENT_CALLS = 5
 semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
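
MAX_CONCURRENT_CALLS and the semaphore cap how many provider API calls run at once, independently of how many worker threads the executor spawns. A self-contained sketch of that pattern, with a placeholder call_model in place of the real request logic:

```python
# Bounding concurrent API calls with a semaphore; call_model and the pool size
# are placeholders, not app.py internals.
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

def call_model(task_id: int) -> str:
    with semaphore:                    # at most 5 calls in flight at any moment
        return f"result for task {task_id}"

def run_all(n_tasks: int) -> list[str]:
    results = []
    with ThreadPoolExecutor(max_workers=16) as pool:
        futures = [pool.submit(call_model, i) for i in range(n_tasks)]
        for future in as_completed(futures):
            results.append(future.result())
    return results

if __name__ == "__main__":
    print(len(run_all(20)))
```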
@@ -191,6 +203,7 @@ def get_model_response(question, options, prompt_template, model_name, clients):
             )
             response_text = chat_session.send_message(prompt).text
 
+        # Extract JSON from response
         json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
         if not json_match:
             return f"Error: Invalid response format", response_text
@@ -206,6 +219,7 @@ def get_model_response(question, options, prompt_template, model_name, clients):
     except Exception as e:
         return f"Error: {str(e)}", str(e)
 
+
 def evaluate_response(model_response, correct_answer):
     if model_response.startswith("Error:"):
         return False
@@ -233,10 +247,15 @@ def process_single_evaluation(question, prompt_template, model_name, clients, la
         'explanation': question['explanation'],
         'timestamp': datetime.utcnow().isoformat()
     }
+
     with WRITE_LOCK:
-        file_exists = os.path.isfile(RESULTS_FILE)
-        with open(RESULTS_FILE, 'a', encoding='utf-8', newline='') as f:
-            pd.DataFrame([result]).to_csv(f, header=not file_exists, index=False)
+        if RESULTS_FILE.exists():
+            existing_df = pd.read_csv(RESULTS_FILE)
+            updated_df = existing_df.append(result, ignore_index=True)
+        else:
+            updated_df = pd.DataFrame([result])
+
+        updated_df.to_csv(RESULTS_FILE, index=False)
 
     return result
 
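
One caveat with the new write path: DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so this block raises AttributeError on a current pandas. An equivalent sketch using pd.concat, under the same WRITE_LOCK idea:

```python
# Equivalent append logic with pd.concat, since DataFrame.append was removed
# in pandas 2.0; result is a plain dict as in the diff above.
import threading
from pathlib import Path

import pandas as pd

WRITE_LOCK = threading.Lock()
RESULTS_FILE = Path("data") / "results.csv"

def save_result(result: dict) -> None:
    with WRITE_LOCK:
        new_row = pd.DataFrame([result])
        if RESULTS_FILE.exists():
            existing_df = pd.read_csv(RESULTS_FILE)
            updated_df = pd.concat([existing_df, new_row], ignore_index=True)
        else:
            updated_df = new_row
        updated_df.to_csv(RESULTS_FILE, index=False)
```

Appending with to_csv(mode="a"), as the removed code did, would also avoid re-reading the whole file on every result.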
@@ -245,8 +264,7 @@ def process_evaluations_concurrently(questions, prompt_template, models_to_evalu
     total_iterations = len(models_to_evaluate) * len(questions)
     current_iteration = 0
 
-    # Load existing results to avoid re-processing
-    if os.path.exists(RESULTS_FILE):
+    if RESULTS_FILE.exists():
         existing_df = pd.read_csv(RESULTS_FILE)
         completed = set(zip(existing_df['model'], existing_df['question']))
     else:
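
The read of the existing CSV feeds a resume check: (model, question) pairs already on disk are skipped, so a rerun does not re-evaluate work that is already recorded. Roughly, and assuming each question dict carries its text under the 'question' key:

```python
# Sketch of the resume check; load_pending is illustrative, the real function
# also drives the progress callback and the thread pool.
from pathlib import Path

import pandas as pd

RESULTS_FILE = Path("data") / "results.csv"

def load_pending(questions: list[dict], models: list[str]) -> list[tuple[str, dict]]:
    if RESULTS_FILE.exists():
        existing_df = pd.read_csv(RESULTS_FILE)
        completed = set(zip(existing_df["model"], existing_df["question"]))
    else:
        completed = set()
    # Keep only (model, question) pairs that have not been evaluated yet.
    return [(m, q) for m in models for q in questions
            if (m, q["question"]) not in completed]
```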
@@ -283,7 +301,7 @@ def main():
     if 'all_results' not in st.session_state:
         st.session_state.all_results = {}
         st.session_state.last_evaluated_dataset = None
-        if os.path.exists(RESULTS_FILE):
+        if RESULTS_FILE.exists():
             existing_df = pd.read_csv(RESULTS_FILE)
             all_results = {}
             for _, row in existing_df.iterrows():
@@ -302,18 +320,14 @@ def main():
 
     with st.sidebar:
         if st.button("Reset Results"):
-            if os.path.exists(RESULTS_FILE):
-                os.remove(RESULTS_FILE)
-                for file in os.listdir(DATA_DIR):
-                    file_path = os.path.join(DATA_DIR, file)
-                    try:
-                        if os.path.isfile(file_path):
-                            os.unlink(file_path)
-                    except Exception as e:
-                        st.error(f"Error deleting file {file_path}: {e}")
-                st.session_state.all_results = {}
-                st.session_state.last_evaluated_dataset = None
-                st.success("Results have been reset.")
+            if RESULTS_FILE.exists():
+                try:
+                    RESULTS_FILE.unlink()
+                    st.session_state.all_results = {}
+                    st.session_state.last_evaluated_dataset = None
+                    st.success("Results have been reset.")
+                except Exception as e:
+                    st.error(f"Error deleting file: {str(e)}")
             else:
                 st.info("No results to reset.")
 
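
Reset now removes only results.csv through pathlib instead of unlinking everything in data/. The same logic in isolation; with Python 3.8+ Path.unlink(missing_ok=True) would even make the exists() check optional:

```python
# Sketch of the new reset path: only results.csv is removed, via pathlib.
from pathlib import Path

RESULTS_FILE = Path("data") / "results.csv"

def reset_results() -> bool:
    try:
        RESULTS_FILE.unlink(missing_ok=True)  # no-op if the file is already gone
        return True
    except OSError as exc:
        print(f"Error deleting file: {exc}")
        return False
```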
@@ -334,7 +348,6 @@ def main():
 
     models_to_evaluate = selected_models
 
-
     default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
 Question: {question}
 
@@ -358,7 +371,7 @@ Important:
 - Only the "answer" field will be used for evaluation
 - Ensure your response is in valid JSON format'''
 
-
+    # Customize Prompt Template
     col1, col2 = st.columns([2, 1])
     with col1:
         prompt_template = st.text_area(
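
The prompt template is an ordinary str.format template, so {question} and {options} are the only placeholders the text area needs to preserve. How the substitution works, with made-up sample data (the exact option formatting in app.py may differ):

```python
# str.format mechanics behind the editable prompt; the question and options
# below are invented sample data for illustration only.
prompt_template = """You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}

Options:
{options}
"""

question = "Which vitamin deficiency causes scurvy?"
options = ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"]
options_text = "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))

print(prompt_template.format(question=question, options=options_text))
```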
@@ -375,74 +388,90 @@ Important:
         - `{options}`: The multiple choice options
         """)
 
+    # Load Dataset
+    if st.session_state.api_configured:
+        with st.spinner("Loading dataset..."):
+            questions = load_dataset_by_name(selected_dataset)
+    else:
+        st.warning("Please configure the API keys in the sidebar to load datasets and proceed.")
+        questions = []
 
-    with st.spinner("Loading dataset..."):
-        questions = load_dataset_by_name(selected_dataset)
-
-
-    subjects = sorted(list(set(q['subject_name'] for q in questions)))
-    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
-
-    if selected_subject != "All":
-        questions = [q for q in questions if q['subject_name'] == selected_subject]
-
-
-    num_questions = st.number_input("Number of questions to evaluate", min_value=1, max_value=len(questions), value=1, step=1)
-
-
-    if st.button("Start Evaluation"):
-        with st.spinner("Starting evaluation..."):
-            selected_questions = questions[:num_questions]
-
-
-            clients = {
-                "togetherai": st.session_state["togetherai_client"],
-                "openai": st.session_state["openai_client"],
-                "anthropic": st.session_state["anthropic_client"]
-            }
-
-
-            last_evaluated_dataset = st.session_state.last_evaluated_dataset if st.session_state.last_evaluated_dataset else selected_dataset
-
-            progress_container = st.container()
-            progress_bar = progress_container.progress(0)
-            status_text = progress_container.empty()
-
-            def update_progress(current, total):
-                progress = current / total
-                progress_bar.progress(progress)
-                status_text.text(f"Progress: {current}/{total} evaluations completed")
-
-            results = process_evaluations_concurrently(
-                selected_questions,
-                prompt_template,
-                models_to_evaluate,
-                update_progress,
-                clients,
-                last_evaluated_dataset
-            )
+    # Filter by Subject
+    if questions:
+        subjects = sorted(list(set(q['subject_name'] for q in questions)))
+        selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
 
-            all_results = st.session_state.all_results.copy()
-            for result in results:
-                model = result.pop('model')
-                if model not in all_results:
-                    all_results[model] = []
-                all_results[model].append(result)
-
-            st.session_state.all_results = all_results
-            st.session_state.last_evaluated_dataset = selected_dataset
-
-            if st.session_state.detailed_model is None and all_results:
-                st.session_state.detailed_model = list(all_results.keys())[0]
-            if st.session_state.detailed_dataset is None:
-                st.session_state.detailed_dataset = selected_dataset
+        if selected_subject != "All":
+            questions = [q for q in questions if q['subject_name'] == selected_subject]
+
+        # Number of Questions to Evaluate
+        num_questions = st.number_input(
+            "Number of questions to evaluate",
+            min_value=1,
+            max_value=len(questions),
+            value=min(10, len(questions)),
+            step=1
+        )
 
-            st.success("Evaluation completed!")
-            st.rerun()
+        # Start Evaluation Button
+        if st.button("Start Evaluation"):
+            if not models_to_evaluate:
+                st.error("Please select at least one model to evaluate.")
+            else:
+                with st.spinner("Starting evaluation..."):
+                    selected_questions = questions[:num_questions]
+
+                    clients = {
+                        "togetherai": st.session_state["togetherai_client"],
+                        "openai": st.session_state["openai_client"],
+                        "anthropic": st.session_state["anthropic_client"]
+                    }
+
+                    last_evaluated_dataset = st.session_state.last_evaluated_dataset if st.session_state.last_evaluated_dataset else selected_dataset
 
+                    progress_container = st.container()
+                    progress_bar = progress_container.progress(0)
+                    status_text = progress_container.empty()
+
+                    def update_progress(current, total):
+                        progress = current / total
+                        progress_bar.progress(progress)
+                        status_text.text(f"Progress: {current}/{total} evaluations completed")
+
+                    results = process_evaluations_concurrently(
+                        selected_questions,
+                        prompt_template,
+                        models_to_evaluate,
+                        update_progress,
+                        clients,
+                        last_evaluated_dataset
+                    )
+
+                    # Update Session State with New Results
+                    all_results = st.session_state.all_results.copy()
+                    for result in results:
+                        model = result.pop('model')
+                        if model not in all_results:
+                            all_results[model] = []
+                        all_results[model].append(result)
+
+                    st.session_state.all_results = all_results
+                    st.session_state.last_evaluated_dataset = selected_dataset
+
+                    # Set Default Detailed Model and Dataset if Not Set
+                    if st.session_state.detailed_model is None and all_results:
+                        st.session_state.detailed_model = list(all_results.keys())[0]
+                    if st.session_state.detailed_dataset is None:
+                        st.session_state.detailed_dataset = selected_dataset
+
+                    st.success("Evaluation completed!")
+                    st.experimental_rerun()
+
+    # Display Evaluation Results
     if st.session_state.all_results:
         st.subheader("Evaluation Results")
-
+        model_metrics = {}
+
         for model_name, results in st.session_state.all_results.items():
             df = pd.DataFrame(results)
             metrics = {
@@ -450,27 +479,28 @@ Important:
             }
             model_metrics[model_name] = metrics
 
-        metrics_df = pd.DataFrame(model_metrics).T
+        metrics_df = pd.DataFrame(model_metrics).T.reset_index().rename(columns={'index': 'Model'})
 
         st.subheader("Model Performance Comparison")
         accuracy_chart = alt.Chart(
-            metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
+            metrics_df
        ).mark_bar().encode(
-            x=alt.X('index:N', title=None, axis=None),
-            y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
-            color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
-            tooltip=['index:N', 'value:Q']
+            x=alt.X('Model:N', title=None),
+            y=alt.Y('Accuracy:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
+            color=alt.Color('Model:N', scale=alt.Scale(scheme='blues')),
+            tooltip=['Model:N', 'Accuracy:Q']
         ).properties(
             height=300,
             title={
                 "text": "Model Accuracy",
-                "baseline": "bottom",
-                "orient": "bottom",
-                "dy": 20
+                "anchor": "middle",
+                "fontSize": 20
             }
-        )
+        ).interactive()
 
         st.altair_chart(accuracy_chart, use_container_width=True)
+
+    # Display Detailed Results
     if st.session_state.all_results:
         st.subheader("Detailed Results")
 
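
The chart now receives a tidy DataFrame with explicit Model and Accuracy columns instead of a melt over the transposed metrics, which makes the encodings self-describing. A standalone version with made-up accuracy numbers:

```python
# Standalone version of the reshaped chart input; the two accuracy values are
# invented sample numbers, only the reshaping and encodings mirror the diff.
import altair as alt
import pandas as pd

model_metrics = {
    "model-a": {"Accuracy": 0.72},
    "model-b": {"Accuracy": 0.65},
}
metrics_df = pd.DataFrame(model_metrics).T.reset_index().rename(columns={"index": "Model"})

accuracy_chart = alt.Chart(metrics_df).mark_bar().encode(
    x=alt.X("Model:N", title=None),
    y=alt.Y("Accuracy:Q", title="Accuracy", scale=alt.Scale(domain=[0, 1])),
    color=alt.Color("Model:N", scale=alt.Scale(scheme="blues")),
    tooltip=["Model:N", "Accuracy:Q"],
).properties(
    height=300,
    title={"text": "Model Accuracy", "anchor": "middle", "fontSize": 20},
).interactive()

accuracy_chart.save("model_accuracy.html")
```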
@@ -494,12 +524,12 @@ Important:
         with col2:
             selected_dataset_details = st.selectbox(
                 "Select dataset",
-                options=[st.session_state.last_evaluated_dataset],
+                options=[st.session_state.last_evaluated_dataset] if st.session_state.last_evaluated_dataset else [],
                 key="dataset_select",
                 on_change=update_dataset
             )
 
-        if selected_model_details in st.session_state.all_results:
+        if selected_model_details and selected_model_details in st.session_state.all_results:
             results = st.session_state.all_results[selected_model_details]
             df = pd.DataFrame(results)
             accuracy = df['is_correct'].mean()
@@ -537,21 +567,17 @@ Important:
 
     st.markdown("---")
     st.subheader("Download Results")
-    all_data = []
-    for model_name, results in st.session_state.all_results.items():
-        for result in results:
-            row = result.copy()
-            all_data.append(row)
-
-    complete_df = pd.DataFrame(all_data)
-    csv = complete_df.to_csv(index=False)
-    st.download_button(
-        label="Download All Results as CSV",
-        data=csv,
-        file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
-        mime="text/csv",
-        key="download_all_results"
-    )
+    if RESULTS_FILE.exists():
+        csv_data = RESULTS_FILE.read_text(encoding='utf-8')
+        st.download_button(
+            label="Download All Results as CSV",
+            data=csv_data,
+            file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
+            mime="text/csv",
+            key="download_all_results"
+        )
+    else:
+        st.info("No data available to download.")
 
 if __name__ == "__main__":
     main()
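
The download button now just streams the raw results.csv from disk rather than rebuilding a DataFrame from session state. Reduced to its essentials (the file name is shortened for the sketch):

```python
# Minimal sketch of serving an existing CSV through st.download_button;
# the surrounding app logic and dynamic file name are omitted.
from pathlib import Path

import streamlit as st

RESULTS_FILE = Path("data") / "results.csv"

if RESULTS_FILE.exists():
    st.download_button(
        label="Download All Results as CSV",
        data=RESULTS_FILE.read_text(encoding="utf-8"),
        file_name="all_models_results.csv",
        mime="text/csv",
    )
else:
    st.info("No data available to download.")
```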
 