Update app.py
Browse files
app.py
CHANGED
@@ -16,30 +16,12 @@ from anthropic import Anthropic
|
|
16 |
import google.generativeai as genai
|
17 |
import hmac
|
18 |
import hashlib
|
19 |
-
from uuid import uuid4
|
20 |
-
from datetime import datetime
|
21 |
-
from huggingface_hub import CommitScheduler, Repository
|
22 |
-
from pathlib import Path
|
23 |
|
24 |
-
load_dotenv()
|
25 |
-
|
26 |
-
st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
|
27 |
|
28 |
-
WRITE_LOCK = threading.Lock()
|
29 |
-
DATA_DIR = Path("data")
|
30 |
-
DATA_DIR.mkdir(exist_ok=True)
|
31 |
-
RESULTS_FILE = DATA_DIR / "results.csv"
|
32 |
|
33 |
|
34 |
-
scheduler = CommitScheduler(
|
35 |
-
repo_id=os.getenv("HF_REPO_ID"),
|
36 |
-
repo_type="dataset",
|
37 |
-
folder_path=DATA_DIR,
|
38 |
-
path_in_repo="data",
|
39 |
-
every=10,
|
40 |
-
token=os.getenv("HF_TOKEN")
|
41 |
-
)
|
42 |
|
|
|
43 |
def initialize_session_state():
|
44 |
if 'api_configured' not in st.session_state:
|
45 |
st.session_state.api_configured = False
|
@@ -49,18 +31,10 @@ def initialize_session_state():
|
|
49 |
st.session_state.openai_client = None
|
50 |
if 'anthropic_client' not in st.session_state:
|
51 |
st.session_state.anthropic_client = None
|
52 |
-
if 'all_results' not in st.session_state:
|
53 |
-
st.session_state.all_results = {}
|
54 |
-
if 'detailed_model' not in st.session_state:
|
55 |
-
st.session_state.detailed_model = None
|
56 |
-
if 'detailed_dataset' not in st.session_state:
|
57 |
-
st.session_state.detailed_dataset = None
|
58 |
-
if 'last_evaluated_dataset' not in st.session_state:
|
59 |
-
st.session_state.last_evaluated_dataset = None
|
60 |
-
|
61 |
-
initialize_session_state()
|
62 |
|
63 |
def setup_api_clients():
|
|
|
|
|
64 |
with st.sidebar:
|
65 |
st.title("API Configuration")
|
66 |
|
@@ -71,29 +45,22 @@ def setup_api_clients():
|
|
71 |
password = st.text_input("Password", type="password")
|
72 |
|
73 |
if st.button("Verify Credentials"):
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
|
91 |
-
|
92 |
-
st.session_state.api_configured = True
|
93 |
-
st.success("Successfully configured the API clients with stored keys!")
|
94 |
-
except Exception as e:
|
95 |
-
st.error(f"Error initializing API clients: {str(e)}")
|
96 |
-
st.session_state.api_configured = False
|
97 |
else:
|
98 |
st.error("Invalid credentials. Please try again or use your own API keys.")
|
99 |
st.session_state.api_configured = False
|
@@ -124,11 +91,6 @@ def setup_api_clients():
|
|
124 |
st.error(f"Error initializing API clients: {str(e)}")
|
125 |
st.session_state.api_configured = False
|
126 |
|
127 |
-
setup_api_clients()
|
128 |
-
|
129 |
-
|
130 |
-
scheduler.start()
|
131 |
-
|
132 |
MAX_CONCURRENT_CALLS = 5
|
133 |
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
|
134 |
|
@@ -154,7 +116,7 @@ def load_dataset_by_name(dataset_name, split="train"):
|
|
154 |
}
|
155 |
questions.append(question_dict)
|
156 |
|
157 |
-
st.write(f"Loaded {len(questions)} single-select questions from
|
158 |
return questions
|
159 |
|
160 |
@retry(
|
@@ -162,6 +124,7 @@ def load_dataset_by_name(dataset_name, split="train"):
|
|
162 |
stop=stop_after_attempt(5),
|
163 |
retry=retry_if_exception_type(Exception)
|
164 |
)
|
|
|
165 |
def get_model_response(question, options, prompt_template, model_name, clients):
|
166 |
with semaphore:
|
167 |
try:
|
@@ -203,7 +166,6 @@ def get_model_response(question, options, prompt_template, model_name, clients):
|
|
203 |
)
|
204 |
response_text = chat_session.send_message(prompt).text
|
205 |
|
206 |
-
# Extract JSON from response
|
207 |
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
208 |
if not json_match:
|
209 |
return f"Error: Invalid response format", response_text
|
@@ -219,14 +181,13 @@ def get_model_response(question, options, prompt_template, model_name, clients):
|
|
219 |
except Exception as e:
|
220 |
return f"Error: {str(e)}", str(e)
|
221 |
|
222 |
-
|
223 |
def evaluate_response(model_response, correct_answer):
|
224 |
if model_response.startswith("Error:"):
|
225 |
return False
|
226 |
is_correct = model_response.lower().strip() == correct_answer.lower().strip()
|
227 |
return is_correct
|
228 |
|
229 |
-
def process_single_evaluation(question, prompt_template, model_name, clients
|
230 |
answer, response_text = get_model_response(
|
231 |
question['question'],
|
232 |
question['options'],
|
@@ -235,57 +196,29 @@ def process_single_evaluation(question, prompt_template, model_name, clients, la
|
|
235 |
clients
|
236 |
)
|
237 |
is_correct = evaluate_response(answer, question['correct_answer'])
|
238 |
-
|
239 |
-
'dataset': last_evaluated_dataset,
|
240 |
-
'model': model_name,
|
241 |
'question': question['question'],
|
|
|
|
|
|
|
|
|
242 |
'correct_answer': question['correct_answer'],
|
243 |
'subject': question['subject_name'],
|
244 |
-
'options': ' | '.join(question['options']),
|
245 |
-
'model_response': answer,
|
246 |
'is_correct': is_correct,
|
247 |
'explanation': question['explanation'],
|
248 |
-
'
|
249 |
}
|
250 |
-
|
251 |
-
with WRITE_LOCK:
|
252 |
-
if RESULTS_FILE.exists():
|
253 |
-
existing_df = pd.read_csv(RESULTS_FILE)
|
254 |
-
updated_df = existing_df.append(result, ignore_index=True)
|
255 |
-
else:
|
256 |
-
updated_df = pd.DataFrame([result])
|
257 |
-
|
258 |
-
updated_df.to_csv(RESULTS_FILE, index=False)
|
259 |
-
|
260 |
-
return result
|
261 |
|
262 |
-
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients
|
263 |
results = []
|
264 |
total_iterations = len(models_to_evaluate) * len(questions)
|
265 |
current_iteration = 0
|
266 |
|
267 |
-
if RESULTS_FILE.exists():
|
268 |
-
existing_df = pd.read_csv(RESULTS_FILE)
|
269 |
-
completed = set(zip(existing_df['model'], existing_df['question']))
|
270 |
-
else:
|
271 |
-
completed = set()
|
272 |
-
|
273 |
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
|
274 |
future_to_params = {}
|
275 |
for model_name in models_to_evaluate:
|
276 |
for question in questions:
|
277 |
-
|
278 |
-
current_iteration += 1
|
279 |
-
progress_callback(current_iteration, total_iterations)
|
280 |
-
continue # Skip already completed evaluations
|
281 |
-
future = executor.submit(
|
282 |
-
process_single_evaluation,
|
283 |
-
question,
|
284 |
-
prompt_template,
|
285 |
-
model_name,
|
286 |
-
clients,
|
287 |
-
last_evaluated_dataset
|
288 |
-
)
|
289 |
future_to_params[future] = (model_name, question)
|
290 |
|
291 |
for future in as_completed(future_to_params):
|
@@ -295,42 +228,25 @@ def process_evaluations_concurrently(questions, prompt_template, models_to_evalu
|
|
295 |
progress_callback(current_iteration, total_iterations)
|
296 |
|
297 |
return results
|
298 |
-
|
299 |
-
|
300 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
if 'all_results' not in st.session_state:
|
302 |
st.session_state.all_results = {}
|
|
|
|
|
|
|
|
|
|
|
303 |
st.session_state.last_evaluated_dataset = None
|
304 |
-
if RESULTS_FILE.exists():
|
305 |
-
existing_df = pd.read_csv(RESULTS_FILE)
|
306 |
-
all_results = {}
|
307 |
-
for _, row in existing_df.iterrows():
|
308 |
-
model = row['model']
|
309 |
-
result = row.to_dict()
|
310 |
-
if model not in all_results:
|
311 |
-
all_results[model] = []
|
312 |
-
all_results[model].append(result)
|
313 |
-
st.session_state.all_results = all_results
|
314 |
-
st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
|
315 |
-
st.info(f"Loaded existing results from `{RESULTS_FILE}`.")
|
316 |
-
else:
|
317 |
-
st.session_state.all_results = {}
|
318 |
-
st.session_state.last_evaluated_dataset = None
|
319 |
-
st.info(f"No existing results found. Ready to start fresh.")
|
320 |
-
|
321 |
-
with st.sidebar:
|
322 |
-
if st.button("Reset Results"):
|
323 |
-
if RESULTS_FILE.exists():
|
324 |
-
try:
|
325 |
-
RESULTS_FILE.unlink()
|
326 |
-
st.session_state.all_results = {}
|
327 |
-
st.session_state.last_evaluated_dataset = None
|
328 |
-
st.success("Results have been reset.")
|
329 |
-
except Exception as e:
|
330 |
-
st.error(f"Error deleting file: {str(e)}")
|
331 |
-
else:
|
332 |
-
st.info("No results to reset.")
|
333 |
-
|
334 |
col1, col2 = st.columns(2)
|
335 |
with col1:
|
336 |
selected_dataset = st.selectbox(
|
@@ -339,46 +255,41 @@ def main():
|
|
339 |
help="Choose the dataset to evaluate on"
|
340 |
)
|
341 |
with col2:
|
342 |
-
|
343 |
"Select Model(s)",
|
344 |
options=list(MODELS.keys()),
|
345 |
default=[list(MODELS.keys())[0]],
|
346 |
help="Choose one or more models to evaluate."
|
347 |
)
|
348 |
|
349 |
-
models_to_evaluate =
|
350 |
|
351 |
default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
|
352 |
Question: {question}
|
353 |
-
|
354 |
Options:
|
355 |
{options}
|
356 |
-
|
357 |
## Output Format:
|
358 |
Please provide your answer in JSON format that contains an "answer" field.
|
359 |
You may include any additional fields in your JSON response that you find relevant, such as:
|
360 |
- "choice reasoning": your detailed reasoning
|
361 |
- "elimination reasoning": why you ruled out other options
|
362 |
-
|
363 |
Example response format:
|
364 |
{
|
365 |
"answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
|
366 |
"choice reasoning": "your detailed reasoning here",
|
367 |
"elimination reasoning": "why you ruled out other options"
|
368 |
}
|
369 |
-
|
370 |
Important:
|
371 |
- Only the "answer" field will be used for evaluation
|
372 |
- Ensure your response is in valid JSON format'''
|
373 |
|
374 |
-
# Customize Prompt Template
|
375 |
col1, col2 = st.columns([2, 1])
|
376 |
with col1:
|
377 |
prompt_template = st.text_area(
|
378 |
"Customize Prompt Template",
|
379 |
default_prompt,
|
380 |
height=400,
|
381 |
-
help="
|
382 |
)
|
383 |
|
384 |
with col2:
|
@@ -388,90 +299,67 @@ Important:
|
|
388 |
- `{options}`: The multiple choice options
|
389 |
""")
|
390 |
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
questions = []
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
step=1
|
414 |
-
)
|
415 |
|
416 |
-
# Start Evaluation Button
|
417 |
-
if st.button("Start Evaluation"):
|
418 |
-
if not models_to_evaluate:
|
419 |
-
st.error("Please select at least one model to evaluate.")
|
420 |
-
else:
|
421 |
-
with st.spinner("Starting evaluation..."):
|
422 |
-
selected_questions = questions[:num_questions]
|
423 |
-
|
424 |
-
clients = {
|
425 |
-
"togetherai": st.session_state["togetherai_client"],
|
426 |
-
"openai": st.session_state["openai_client"],
|
427 |
-
"anthropic": st.session_state["anthropic_client"]
|
428 |
-
}
|
429 |
-
|
430 |
-
last_evaluated_dataset = st.session_state.last_evaluated_dataset if st.session_state.last_evaluated_dataset else selected_dataset
|
431 |
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
results = process_evaluations_concurrently(
|
442 |
-
selected_questions,
|
443 |
-
prompt_template,
|
444 |
-
models_to_evaluate,
|
445 |
-
update_progress,
|
446 |
-
clients,
|
447 |
-
last_evaluated_dataset
|
448 |
-
)
|
449 |
-
|
450 |
-
# Update Session State with New Results
|
451 |
-
all_results = st.session_state.all_results.copy()
|
452 |
-
for result in results:
|
453 |
-
model = result.pop('model')
|
454 |
-
if model not in all_results:
|
455 |
-
all_results[model] = []
|
456 |
-
all_results[model].append(result)
|
457 |
-
|
458 |
-
st.session_state.all_results = all_results
|
459 |
-
st.session_state.last_evaluated_dataset = selected_dataset
|
460 |
-
|
461 |
-
# Set Default Detailed Model and Dataset if Not Set
|
462 |
-
if st.session_state.detailed_model is None and all_results:
|
463 |
-
st.session_state.detailed_model = list(all_results.keys())[0]
|
464 |
-
if st.session_state.detailed_dataset is None:
|
465 |
-
st.session_state.detailed_dataset = selected_dataset
|
466 |
-
|
467 |
-
st.success("Evaluation completed!")
|
468 |
-
st.experimental_rerun()
|
469 |
-
|
470 |
-
# Display Evaluation Results
|
471 |
if st.session_state.all_results:
|
472 |
st.subheader("Evaluation Results")
|
|
|
473 |
model_metrics = {}
|
474 |
-
|
475 |
for model_name, results in st.session_state.all_results.items():
|
476 |
df = pd.DataFrame(results)
|
477 |
metrics = {
|
@@ -479,28 +367,29 @@ Important:
|
|
479 |
}
|
480 |
model_metrics[model_name] = metrics
|
481 |
|
482 |
-
metrics_df = pd.DataFrame(model_metrics).T
|
483 |
|
484 |
st.subheader("Model Performance Comparison")
|
|
|
485 |
accuracy_chart = alt.Chart(
|
486 |
-
metrics_df
|
487 |
).mark_bar().encode(
|
488 |
-
x=alt.X('
|
489 |
-
y=alt.Y('
|
490 |
-
color=alt.Color('
|
491 |
-
tooltip=['
|
492 |
).properties(
|
493 |
height=300,
|
494 |
title={
|
495 |
"text": "Model Accuracy",
|
496 |
-
"
|
497 |
-
"
|
|
|
498 |
}
|
499 |
-
)
|
500 |
|
501 |
st.altair_chart(accuracy_chart, use_container_width=True)
|
502 |
|
503 |
-
# Display Detailed Results
|
504 |
if st.session_state.all_results:
|
505 |
st.subheader("Detailed Results")
|
506 |
|
@@ -524,12 +413,12 @@ Important:
|
|
524 |
with col2:
|
525 |
selected_dataset_details = st.selectbox(
|
526 |
"Select dataset",
|
527 |
-
options=[st.session_state.last_evaluated_dataset]
|
528 |
key="dataset_select",
|
529 |
on_change=update_dataset
|
530 |
)
|
531 |
|
532 |
-
if selected_model_details
|
533 |
results = st.session_state.all_results[selected_model_details]
|
534 |
df = pd.DataFrame(results)
|
535 |
accuracy = df['is_correct'].mean()
|
@@ -540,16 +429,16 @@ Important:
|
|
540 |
with st.expander(f"Question {idx + 1} - {result['subject']}"):
|
541 |
st.write("**Question:**", result['question'])
|
542 |
st.write("**Options:**")
|
543 |
-
for i, opt in enumerate(result['options']
|
544 |
st.write(f"{chr(65+i)}. {opt}")
|
545 |
|
546 |
col1, col2 = st.columns(2)
|
547 |
with col1:
|
548 |
-
st.write("**
|
549 |
-
st.code(result
|
550 |
with col2:
|
551 |
-
st.write("**
|
552 |
-
st.code(result
|
553 |
|
554 |
col1, col2 = st.columns(2)
|
555 |
with col1:
|
@@ -561,23 +450,39 @@ Important:
|
|
561 |
else:
|
562 |
st.error("Incorrect")
|
563 |
|
564 |
-
st.write("**
|
565 |
else:
|
566 |
st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
|
567 |
|
568 |
st.markdown("---")
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
581 |
|
582 |
if __name__ == "__main__":
|
583 |
-
main()
|
|
|
16 |
import google.generativeai as genai
|
17 |
import hmac
|
18 |
import hashlib
|
|
|
|
|
|
|
|
|
19 |
|
|
|
|
|
|
|
20 |
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
load_dotenv()
|
25 |
def initialize_session_state():
|
26 |
if 'api_configured' not in st.session_state:
|
27 |
st.session_state.api_configured = False
|
|
|
31 |
st.session_state.openai_client = None
|
32 |
if 'anthropic_client' not in st.session_state:
|
33 |
st.session_state.anthropic_client = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def setup_api_clients():
|
36 |
+
initialize_session_state()
|
37 |
+
|
38 |
with st.sidebar:
|
39 |
st.title("API Configuration")
|
40 |
|
|
|
45 |
password = st.text_input("Password", type="password")
|
46 |
|
47 |
if st.button("Verify Credentials"):
|
48 |
+
if (hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", "")) and
|
49 |
+
hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))):
|
50 |
+
st.session_state.togetherai_client = OpenAI(
|
51 |
+
api_key=os.getenv('TOGETHERAI_API_KEY'),
|
52 |
+
base_url="https://api.together.xyz/v1"
|
53 |
+
)
|
54 |
+
st.session_state.openai_client = OpenAI(
|
55 |
+
api_key=os.getenv('OPENAI_API_KEY')
|
56 |
+
)
|
57 |
+
st.session_state.anthropic_client = Anthropic(
|
58 |
+
api_key=os.getenv('ANTHROPIC_API_KEY')
|
59 |
+
)
|
60 |
+
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
|
61 |
+
|
62 |
+
st.session_state.api_configured = True
|
63 |
+
st.success("Successfully configured the API clients with stored keys!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
else:
|
65 |
st.error("Invalid credentials. Please try again or use your own API keys.")
|
66 |
st.session_state.api_configured = False
|
|
|
91 |
st.error(f"Error initializing API clients: {str(e)}")
|
92 |
st.session_state.api_configured = False
|
93 |
|
|
|
|
|
|
|
|
|
|
|
94 |
MAX_CONCURRENT_CALLS = 5
|
95 |
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
|
96 |
|
|
|
116 |
}
|
117 |
questions.append(question_dict)
|
118 |
|
119 |
+
st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
|
120 |
return questions
|
121 |
|
122 |
@retry(
|
|
|
124 |
stop=stop_after_attempt(5),
|
125 |
retry=retry_if_exception_type(Exception)
|
126 |
)
|
127 |
+
|
128 |
def get_model_response(question, options, prompt_template, model_name, clients):
|
129 |
with semaphore:
|
130 |
try:
|
|
|
166 |
)
|
167 |
response_text = chat_session.send_message(prompt).text
|
168 |
|
|
|
169 |
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
170 |
if not json_match:
|
171 |
return f"Error: Invalid response format", response_text
|
|
|
181 |
except Exception as e:
|
182 |
return f"Error: {str(e)}", str(e)
|
183 |
|
|
|
184 |
def evaluate_response(model_response, correct_answer):
|
185 |
if model_response.startswith("Error:"):
|
186 |
return False
|
187 |
is_correct = model_response.lower().strip() == correct_answer.lower().strip()
|
188 |
return is_correct
|
189 |
|
190 |
+
def process_single_evaluation(question, prompt_template, model_name, clients):
|
191 |
answer, response_text = get_model_response(
|
192 |
question['question'],
|
193 |
question['options'],
|
|
|
196 |
clients
|
197 |
)
|
198 |
is_correct = evaluate_response(answer, question['correct_answer'])
|
199 |
+
return {
|
|
|
|
|
200 |
'question': question['question'],
|
201 |
+
'options': question['options'],
|
202 |
+
'model_response': answer,
|
203 |
+
'raw_llm_response': response_text,
|
204 |
+
'prompt_sent': prompt_template,
|
205 |
'correct_answer': question['correct_answer'],
|
206 |
'subject': question['subject_name'],
|
|
|
|
|
207 |
'is_correct': is_correct,
|
208 |
'explanation': question['explanation'],
|
209 |
+
'model_name': model_name
|
210 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
+
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients):
|
213 |
results = []
|
214 |
total_iterations = len(models_to_evaluate) * len(questions)
|
215 |
current_iteration = 0
|
216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
|
218 |
future_to_params = {}
|
219 |
for model_name in models_to_evaluate:
|
220 |
for question in questions:
|
221 |
+
future = executor.submit(process_single_evaluation, question, prompt_template, model_name, clients)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
future_to_params[future] = (model_name, question)
|
223 |
|
224 |
for future in as_completed(future_to_params):
|
|
|
228 |
progress_callback(current_iteration, total_iterations)
|
229 |
|
230 |
return results
|
231 |
+
|
|
|
232 |
def main():
|
233 |
+
st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
|
234 |
+
|
235 |
+
initialize_session_state()
|
236 |
+
setup_api_clients()
|
237 |
+
|
238 |
+
if not st.session_state.api_configured:
|
239 |
+
st.warning("Please configure API keys in the sidebar to proceed")
|
240 |
+
st.stop()
|
241 |
+
|
242 |
if 'all_results' not in st.session_state:
|
243 |
st.session_state.all_results = {}
|
244 |
+
if 'detailed_model' not in st.session_state:
|
245 |
+
st.session_state.detailed_model = None
|
246 |
+
if 'detailed_dataset' not in st.session_state:
|
247 |
+
st.session_state.detailed_dataset = None
|
248 |
+
if 'last_evaluated_dataset' not in st.session_state:
|
249 |
st.session_state.last_evaluated_dataset = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
col1, col2 = st.columns(2)
|
251 |
with col1:
|
252 |
selected_dataset = st.selectbox(
|
|
|
255 |
help="Choose the dataset to evaluate on"
|
256 |
)
|
257 |
with col2:
|
258 |
+
selected_model = st.multiselect(
|
259 |
"Select Model(s)",
|
260 |
options=list(MODELS.keys()),
|
261 |
default=[list(MODELS.keys())[0]],
|
262 |
help="Choose one or more models to evaluate."
|
263 |
)
|
264 |
|
265 |
+
models_to_evaluate = selected_model
|
266 |
|
267 |
default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
|
268 |
Question: {question}
|
|
|
269 |
Options:
|
270 |
{options}
|
|
|
271 |
## Output Format:
|
272 |
Please provide your answer in JSON format that contains an "answer" field.
|
273 |
You may include any additional fields in your JSON response that you find relevant, such as:
|
274 |
- "choice reasoning": your detailed reasoning
|
275 |
- "elimination reasoning": why you ruled out other options
|
|
|
276 |
Example response format:
|
277 |
{
|
278 |
"answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
|
279 |
"choice reasoning": "your detailed reasoning here",
|
280 |
"elimination reasoning": "why you ruled out other options"
|
281 |
}
|
|
|
282 |
Important:
|
283 |
- Only the "answer" field will be used for evaluation
|
284 |
- Ensure your response is in valid JSON format'''
|
285 |
|
|
|
286 |
col1, col2 = st.columns([2, 1])
|
287 |
with col1:
|
288 |
prompt_template = st.text_area(
|
289 |
"Customize Prompt Template",
|
290 |
default_prompt,
|
291 |
height=400,
|
292 |
+
help="The below prompt is editable. Please feel free to edit it before your run."
|
293 |
)
|
294 |
|
295 |
with col2:
|
|
|
299 |
- `{options}`: The multiple choice options
|
300 |
""")
|
301 |
|
302 |
+
with st.spinner("Loading dataset..."):
|
303 |
+
questions = load_dataset_by_name(selected_dataset)
|
304 |
+
subjects = sorted(list(set(q['subject_name'] for q in questions)))
|
305 |
+
selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
|
306 |
+
|
307 |
+
if selected_subject != "All":
|
308 |
+
questions = [q for q in questions if q['subject_name'] == selected_subject]
|
309 |
+
|
310 |
+
num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
|
311 |
+
|
312 |
+
if st.button("Start Evaluation"):
|
313 |
+
with st.spinner("Starting evaluation..."):
|
314 |
+
selected_questions = questions[:num_questions]
|
315 |
+
|
316 |
+
# Create a clients dictionary
|
317 |
+
clients = {
|
318 |
+
"togetherai": st.session_state["togetherai_client"],
|
319 |
+
"openai": st.session_state["openai_client"],
|
320 |
+
"anthropic": st.session_state["anthropic_client"]
|
321 |
+
}
|
322 |
+
|
323 |
+
progress_container = st.container()
|
324 |
+
progress_bar = progress_container.progress(0)
|
325 |
+
status_text = progress_container.empty()
|
326 |
+
|
327 |
+
def update_progress(current, total):
|
328 |
+
progress = current / total
|
329 |
+
progress_bar.progress(progress)
|
330 |
+
status_text.text(f"Progress: {current}/{total} evaluations completed")
|
331 |
+
|
332 |
+
results = process_evaluations_concurrently(
|
333 |
+
selected_questions,
|
334 |
+
prompt_template,
|
335 |
+
models_to_evaluate,
|
336 |
+
update_progress,
|
337 |
+
clients
|
338 |
+
)
|
339 |
|
340 |
+
all_results = {}
|
341 |
+
for result in results:
|
342 |
+
model = result.pop('model_name')
|
343 |
+
if model not in all_results:
|
344 |
+
all_results[model] = []
|
345 |
+
all_results[model].append(result)
|
346 |
+
|
347 |
+
st.session_state.all_results = all_results
|
348 |
+
st.session_state.last_evaluated_dataset = selected_dataset
|
|
|
|
|
349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
+
if st.session_state.detailed_model is None and all_results:
|
352 |
+
st.session_state.detailed_model = list(all_results.keys())[0]
|
353 |
+
if st.session_state.detailed_dataset is None:
|
354 |
+
st.session_state.detailed_dataset = selected_dataset
|
355 |
+
|
356 |
+
st.success("Evaluation completed!")
|
357 |
+
st.rerun()
|
358 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
if st.session_state.all_results:
|
360 |
st.subheader("Evaluation Results")
|
361 |
+
|
362 |
model_metrics = {}
|
|
|
363 |
for model_name, results in st.session_state.all_results.items():
|
364 |
df = pd.DataFrame(results)
|
365 |
metrics = {
|
|
|
367 |
}
|
368 |
model_metrics[model_name] = metrics
|
369 |
|
370 |
+
metrics_df = pd.DataFrame(model_metrics).T
|
371 |
|
372 |
st.subheader("Model Performance Comparison")
|
373 |
+
|
374 |
accuracy_chart = alt.Chart(
|
375 |
+
metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
|
376 |
).mark_bar().encode(
|
377 |
+
x=alt.X('index:N', title=None, axis=None),
|
378 |
+
y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
|
379 |
+
color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
|
380 |
+
tooltip=['index:N', 'value:Q']
|
381 |
).properties(
|
382 |
height=300,
|
383 |
title={
|
384 |
"text": "Model Accuracy",
|
385 |
+
"baseline": "bottom",
|
386 |
+
"orient": "bottom",
|
387 |
+
"dy": 20
|
388 |
}
|
389 |
+
)
|
390 |
|
391 |
st.altair_chart(accuracy_chart, use_container_width=True)
|
392 |
|
|
|
393 |
if st.session_state.all_results:
|
394 |
st.subheader("Detailed Results")
|
395 |
|
|
|
413 |
with col2:
|
414 |
selected_dataset_details = st.selectbox(
|
415 |
"Select dataset",
|
416 |
+
options=[st.session_state.last_evaluated_dataset],
|
417 |
key="dataset_select",
|
418 |
on_change=update_dataset
|
419 |
)
|
420 |
|
421 |
+
if selected_model_details in st.session_state.all_results:
|
422 |
results = st.session_state.all_results[selected_model_details]
|
423 |
df = pd.DataFrame(results)
|
424 |
accuracy = df['is_correct'].mean()
|
|
|
429 |
with st.expander(f"Question {idx + 1} - {result['subject']}"):
|
430 |
st.write("**Question:**", result['question'])
|
431 |
st.write("**Options:**")
|
432 |
+
for i, opt in enumerate(result['options']):
|
433 |
st.write(f"{chr(65+i)}. {opt}")
|
434 |
|
435 |
col1, col2 = st.columns(2)
|
436 |
with col1:
|
437 |
+
st.write("**Prompt Used:**")
|
438 |
+
st.code(result['prompt_sent'])
|
439 |
with col2:
|
440 |
+
st.write("**Raw Response:**")
|
441 |
+
st.code(result['raw_llm_response'])
|
442 |
|
443 |
col1, col2 = st.columns(2)
|
444 |
with col1:
|
|
|
450 |
else:
|
451 |
st.error("Incorrect")
|
452 |
|
453 |
+
st.write("**Explanation:**", result['explanation'])
|
454 |
else:
|
455 |
st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
|
456 |
|
457 |
st.markdown("---")
|
458 |
+
all_data = []
|
459 |
+
|
460 |
+
for model_name, results in st.session_state.all_results.items():
|
461 |
+
for question_idx, result in enumerate(results):
|
462 |
+
row = {
|
463 |
+
'dataset': st.session_state.last_evaluated_dataset,
|
464 |
+
'model': model_name,
|
465 |
+
'question': result['question'],
|
466 |
+
'correct_answer': result['correct_answer'],
|
467 |
+
'subject': result['subject'],
|
468 |
+
'options': ' | '.join(result['options']),
|
469 |
+
'model_response': result['model_response'],
|
470 |
+
'is_correct': result['is_correct'],
|
471 |
+
'explanation': result['explanation']
|
472 |
+
}
|
473 |
+
all_data.append(row)
|
474 |
+
|
475 |
+
complete_df = pd.DataFrame(all_data)
|
476 |
+
|
477 |
+
csv = complete_df.to_csv(index=False)
|
478 |
+
|
479 |
+
st.download_button(
|
480 |
+
label="Download All Results as CSV",
|
481 |
+
data=csv,
|
482 |
+
file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
|
483 |
+
mime="text/csv",
|
484 |
+
key="download_all_results"
|
485 |
+
)
|
486 |
|
487 |
if __name__ == "__main__":
|
488 |
+
main()
|