ehagey committed on
Commit 2322bf2 · verified · 1 Parent(s): 9ed09e4

Update app.py

Files changed (1):
  1. app.py +73 -89
app.py CHANGED
@@ -16,6 +16,8 @@ from anthropic import Anthropic
 import google.generativeai as genai
 import hmac
 import hashlib
+from uuid import uuid4
+from datetime import datetime

 load_dotenv()

@@ -32,16 +34,6 @@ if not os.path.exists(DATA_DIR):
 else:
     st.info(f"`{DATA_DIR}` directory already exists.")

-if os.path.exists(DATA_DIR):
-    files = os.listdir(DATA_DIR)
-    st.write(f"Contents of `{DATA_DIR}` directory:")
-    if files:
-        for file in files:
-            st.write(f"- {file}")
-    else:
-        st.write("The data directory is currently empty.")
-else:
-    st.error(f"`{DATA_DIR}` directory does not exist.")

 def initialize_session_state():
     if 'api_configured' not in st.session_state:
@@ -62,8 +54,7 @@ def initialize_session_state():
         st.session_state.last_evaluated_dataset = None

 def setup_api_clients():
-    initialize_session_state()
-
+    initialize_session_state()
     with st.sidebar:
         st.title("API Configuration")

@@ -76,20 +67,24 @@ def setup_api_clients():
         if st.button("Verify Credentials"):
             if (hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", "")) and
                 hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))):
-                st.session_state.togetherai_client = OpenAI(
-                    api_key=os.getenv('TOGETHERAI_API_KEY'),
-                    base_url="https://api.together.xyz/v1"
-                )
-                st.session_state.openai_client = OpenAI(
-                    api_key=os.getenv('OPENAI_API_KEY')
-                )
-                st.session_state.anthropic_client = Anthropic(
-                    api_key=os.getenv('ANTHROPIC_API_KEY')
-                )
-                genai.configure(api_key=os.environ["GEMINI_API_KEY"])
-
-                st.session_state.api_configured = True
-                st.success("Successfully configured the API clients with stored keys!")
+                try:
+                    st.session_state.togetherai_client = OpenAI(
+                        api_key=os.getenv('TOGETHERAI_API_KEY'),
+                        base_url="https://api.together.xyz/v1"
+                    )
+                    st.session_state.openai_client = OpenAI(
+                        api_key=os.getenv('OPENAI_API_KEY')
+                    )
+                    st.session_state.anthropic_client = Anthropic(
+                        api_key=os.getenv('ANTHROPIC_API_KEY')
+                    )
+                    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+                    st.session_state.api_configured = True
+                    st.success("Successfully configured the API clients with stored keys!")
+                except Exception as e:
+                    st.error(f"Error initializing API clients: {str(e)}")
+                    st.session_state.api_configured = False
             else:
                 st.error("Invalid credentials. Please try again or use your own API keys.")
                 st.session_state.api_configured = False
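For context on the credential check above: `hmac.compare_digest` compares the two strings in constant time, so a rejected login takes the same time regardless of how many leading characters match, avoiding a timing side channel. A minimal standalone sketch of the same pattern, assuming the `STREAMLIT_USERNAME`/`STREAMLIT_PASSWORD` environment variables from the diff; the `credentials_match` helper is illustrative and not part of app.py:

```python
import hmac
import os

def credentials_match(username: str, password: str) -> bool:
    # Constant-time comparison of the submitted credentials against env vars.
    return (
        hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", ""))
        and hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))
    )

if __name__ == "__main__":
    os.environ.setdefault("STREAMLIT_USERNAME", "demo")
    os.environ.setdefault("STREAMLIT_PASSWORD", "demo-password")
    print(credentials_match("demo", "demo-password"))  # True
    print(credentials_match("demo", "wrong"))          # False
```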
@@ -120,6 +115,7 @@ def setup_api_clients():
             st.error(f"Error initializing API clients: {str(e)}")
             st.session_state.api_configured = False

+setup_api_clients()
 MAX_CONCURRENT_CALLS = 5
 semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

@@ -145,7 +141,7 @@ def load_dataset_by_name(dataset_name, split="train"):
             }
             questions.append(question_dict)

-    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
+    st.write(f"Loaded {len(questions)} single-select questions from `{dataset_name}`")
     return questions

 @retry(
@@ -233,14 +229,13 @@ def process_single_evaluation(question, prompt_template, model_name, clients, la
         'options': ' | '.join(question['options']),
         'model_response': answer,
         'is_correct': is_correct,
-        'explanation': question['explanation']
+        'explanation': question['explanation'],
+        'timestamp': datetime.utcnow().isoformat()
     }
-
     with WRITE_LOCK:
         file_exists = os.path.isfile(RESULTS_FILE)
         with open(RESULTS_FILE, 'a', encoding='utf-8', newline='') as f:
-            writer = pd.DataFrame([result])
-            writer.to_csv(f, header=not file_exists, index=False)
+            pd.DataFrame([result]).to_csv(f, header=not file_exists, index=False)

     return result

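The write path above appends one row per evaluation while holding `WRITE_LOCK`, and writes the CSV header only when `RESULTS_FILE` does not exist yet, so worker threads cannot interleave partial rows or duplicate the header. A self-contained sketch of that pattern; the `append_result` helper, file name, and sample row are illustrative:

```python
import os
import threading

import pandas as pd

RESULTS_FILE = "results.csv"   # illustrative path
WRITE_LOCK = threading.Lock()

def append_result(result: dict) -> None:
    # Serialize writers and emit the header only on the first write.
    with WRITE_LOCK:
        file_exists = os.path.isfile(RESULTS_FILE)
        with open(RESULTS_FILE, 'a', encoding='utf-8', newline='') as f:
            pd.DataFrame([result]).to_csv(f, header=not file_exists, index=False)

append_result({"model": "demo-model", "is_correct": True})
```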
@@ -264,7 +259,6 @@ def process_evaluations_concurrently(questions, prompt_template, models_to_evalu
                     current_iteration += 1
                     progress_callback(current_iteration, total_iterations)
                     continue # Skip already completed evaluations
-                # Pass last_evaluated_dataset as an argument
                 future = executor.submit(
                     process_single_evaluation,
                     question,
@@ -283,42 +277,39 @@

     return results

-def main():
-
-    initialize_session_state()
-    setup_api_clients()
-
-    if not st.session_state.api_configured:
-        st.warning("Please configure API keys in the sidebar to proceed")
-        st.stop()

+def main():
     if 'all_results' not in st.session_state:
-        if os.path.exists(RESULTS_FILE):
-            existing_df = pd.read_csv(RESULTS_FILE)
-            all_results = {}
-            for _, row in existing_df.iterrows():
-                model = row['model']
-                result = row.to_dict()
-                if model not in all_results:
-                    all_results[model] = []
-                all_results[model].append(result)
-            st.session_state.all_results = all_results
-            st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
-        else:
-            st.session_state.all_results = {}
-            st.session_state.last_evaluated_dataset = None
-
-    if 'detailed_model' not in st.session_state:
-        st.session_state.detailed_model = None
-    if 'detailed_dataset' not in st.session_state:
-        st.session_state.detailed_dataset = None
-    if 'last_evaluated_dataset' not in st.session_state:
+        st.session_state.all_results = {}
+        st.session_state.last_evaluated_dataset = None
+        if os.path.exists(RESULTS_FILE):
+            existing_df = pd.read_csv(RESULTS_FILE)
+            all_results = {}
+            for _, row in existing_df.iterrows():
+                model = row['model']
+                result = row.to_dict()
+                if model not in all_results:
+                    all_results[model] = []
+                all_results[model].append(result)
+            st.session_state.all_results = all_results
+            st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
+            st.info(f"Loaded existing results from `{RESULTS_FILE}`.")
+        else:
+            st.session_state.all_results = {}
             st.session_state.last_evaluated_dataset = None
+            st.info(f"No existing results found. Ready to start fresh.")

     with st.sidebar:
         if st.button("Reset Results"):
             if os.path.exists(RESULTS_FILE):
                 os.remove(RESULTS_FILE)
+                for file in os.listdir(DATA_DIR):
+                    file_path = os.path.join(DATA_DIR, file)
+                    try:
+                        if os.path.isfile(file_path):
+                            os.unlink(file_path)
+                    except Exception as e:
+                        st.error(f"Error deleting file {file_path}: {e}")
             st.session_state.all_results = {}
             st.session_state.last_evaluated_dataset = None
             st.success("Results have been reset.")
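The new `main()` seeds `st.session_state` from any existing `RESULTS_FILE` by regrouping rows per model. The same regrouping can be reproduced with plain pandas outside Streamlit; a small sketch with illustrative sample rows standing in for the CSV contents:

```python
import pandas as pd

# Illustrative stand-in for pd.read_csv(RESULTS_FILE).
existing_df = pd.DataFrame([
    {"dataset": "demo-dataset", "model": "model-a", "is_correct": True},
    {"dataset": "demo-dataset", "model": "model-b", "is_correct": False},
    {"dataset": "demo-dataset", "model": "model-a", "is_correct": False},
])

# Group rows into {model: [row_dict, ...]}, mirroring how main() rebuilds
# st.session_state.all_results on startup.
all_results = {}
for _, row in existing_df.iterrows():
    all_results.setdefault(row["model"], []).append(row.to_dict())

last_evaluated_dataset = existing_df["dataset"].iloc[-1]
print(sorted(all_results), last_evaluated_dataset)
```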
@@ -333,14 +324,15 @@ def main():
             help="Choose the dataset to evaluate on"
         )
     with col2:
-        selected_model = st.multiselect(
+        selected_models = st.multiselect(
             "Select Model(s)",
             options=list(MODELS.keys()),
             default=[list(MODELS.keys())[0]],
             help="Choose one or more models to evaluate."
         )

-    models_to_evaluate = selected_model
+    models_to_evaluate = selected_models
+

     default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
 Question: {question}
365
  - Only the "answer" field will be used for evaluation
366
  - Ensure your response is in valid JSON format'''
367
 
 
368
  col1, col2 = st.columns([2, 1])
369
  with col1:
370
  prompt_template = st.text_area(
371
  "Customize Prompt Template",
372
  default_prompt,
373
  height=400,
374
- help="The below prompt is editable. Please feel free to edit it before your run."
375
  )
376
 
377
  with col2:
@@ -381,28 +374,34 @@ Important:
     - `{options}`: The multiple choice options
     """)

+
     with st.spinner("Loading dataset..."):
         questions = load_dataset_by_name(selected_dataset)
+
+
     subjects = sorted(list(set(q['subject_name'] for q in questions)))
     selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

     if selected_subject != "All":
         questions = [q for q in questions if q['subject_name'] == selected_subject]

-    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
+
+    num_questions = st.number_input("Number of questions to evaluate", min_value=1, max_value=len(questions), value=1, step=1)
+

     if st.button("Start Evaluation"):
         with st.spinner("Starting evaluation..."):
             selected_questions = questions[:num_questions]

-            # Create a clients dictionary
+
             clients = {
                 "togetherai": st.session_state["togetherai_client"],
                 "openai": st.session_state["openai_client"],
                 "anthropic": st.session_state["anthropic_client"]
             }

-            last_evaluated_dataset = st.session_state.last_evaluated_dataset
+
+            last_evaluated_dataset = st.session_state.last_evaluated_dataset if st.session_state.last_evaluated_dataset else selected_dataset

             progress_container = st.container()
             progress_bar = progress_container.progress(0)
@@ -443,7 +442,6 @@ Important:
     if st.session_state.all_results:
         st.subheader("Evaluation Results")

-        model_metrics = {}
         for model_name, results in st.session_state.all_results.items():
             df = pd.DataFrame(results)
             metrics = {
@@ -454,7 +452,6 @@ Important:
         metrics_df = pd.DataFrame(model_metrics).T

         st.subheader("Model Performance Comparison")
-
         accuracy_chart = alt.Chart(
             metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
         ).mark_bar().encode(
@@ -473,7 +470,6 @@ Important:
         )

         st.altair_chart(accuracy_chart, use_container_width=True)
-
     if st.session_state.all_results:
         st.subheader("Detailed Results")

@@ -518,11 +514,11 @@ Important:

             col1, col2 = st.columns(2)
             with col1:
-                st.write("**Prompt Used:**")
-                st.code(result.get('prompt_sent', "N/A"))
+                st.write("**Model Response:**")
+                st.code(result.get('model_response', "N/A"))
             with col2:
-                st.write("**Raw Response:**")
-                st.code(result.get('raw_llm_response', "N/A"))
+                st.write("**Explanation:**")
+                st.code(result.get('explanation', "N/A"))

             col1, col2 = st.columns(2)
             with col1:
@@ -534,32 +530,20 @@ Important:
                 else:
                     st.error("Incorrect")

-                st.write("**Explanation:**", result['explanation'])
+                st.write("**Timestamp:**", result['timestamp'])
         else:
             st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

     st.markdown("---")
+    st.subheader("Download Results")
     all_data = []
-
     for model_name, results in st.session_state.all_results.items():
-        for question_idx, result in enumerate(results):
-            row = {
-                'dataset': st.session_state.last_evaluated_dataset,
-                'model': model_name,
-                'question': result['question'],
-                'correct_answer': result['correct_answer'],
-                'subject': result['subject'],
-                'options': result['options'],
-                'model_response': result['model_response'],
-                'is_correct': result['is_correct'],
-                'explanation': result['explanation']
-            }
+        for result in results:
+            row = result.copy()
             all_data.append(row)

     complete_df = pd.DataFrame(all_data)
-
     csv = complete_df.to_csv(index=False)
-
     st.download_button(
         label="Download All Results as CSV",
         data=csv,
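At the end of the diff the per-model results are flattened into a single DataFrame and exposed through `st.download_button`; the hunk is cut off before the remaining keyword arguments of that call. A minimal sketch of the same flow, with illustrative sample rows and illustrative `file_name`/`mime` values:

```python
import pandas as pd
import streamlit as st

# Illustrative rows; in the app they come from st.session_state.all_results.
all_data = [
    {"model": "model-a", "question": "Q1", "is_correct": True},
    {"model": "model-b", "question": "Q1", "is_correct": False},
]

complete_df = pd.DataFrame(all_data)
csv = complete_df.to_csv(index=False)

st.download_button(
    label="Download All Results as CSV",
    data=csv,
    file_name="evaluation_results.csv",  # illustrative
    mime="text/csv",                     # illustrative
)
```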
 