"""Streamlit app for benchmarking LLMs on healthcare multiple-choice QA.

Questions from a selected dataset are sent to one or more models (Together AI,
OpenAI, Anthropic, Google), answers are parsed from JSON replies, and accuracy
is reported per model.
"""
import hmac
import json
import os
import re
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

import altair as alt
import google.generativeai as genai
import pandas as pd
import streamlit as st
from anthropic import Anthropic
from datasets import load_dataset
from dotenv import load_dotenv
from openai import OpenAI
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from config import DATASETS, MODELS

# Pull API keys and login credentials from a local .env file, if present.
load_dotenv()

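# config.py supplies the dataset and model registries used below. Entries are
# expected to look roughly like this (illustrative values only; see config.py
# for the real ones):
#
#   DATASETS = {"MedMCQA": {"loader": "openlifescienceai/medmcqa"}}
#   MODELS = {"GPT-4o": {"provider": "openai", "model_id": "gpt-4o"}}
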
def check_password():
    """Gate the app behind a login form; credentials come from env vars."""

    def password_entered():
        # compare_digest runs in constant time, avoiding credential timing leaks.
        if (hmac.compare_digest(st.session_state["username"], os.environ.get("STREAMLIT_USERNAME", ""))
                and hmac.compare_digest(st.session_state["password"], os.environ.get("STREAMLIT_PASSWORD", ""))):
            st.session_state["password_correct"] = True
            # Drop the raw credentials from session state once checked.
            del st.session_state["password"]
            del st.session_state["username"]
        else:
            st.session_state["password_correct"] = False

    def show_login_form():
        st.text_input("Username", key="username")
        st.text_input("Password", type="password", key="password")
        st.button("Login", on_click=password_entered)

    if "password_correct" not in st.session_state:
        # First visit: show the form and wait for the callback to run.
        show_login_form()
        return False
    elif not st.session_state["password_correct"]:
        show_login_form()
        st.error("User not known or password incorrect")
        return False
    return True

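# API clients, one per provider. Together AI exposes an OpenAI-compatible
# endpoint, so it is driven through the OpenAI SDK with a custom base_url.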
togetherai_client = OpenAI(
    api_key=os.getenv('TOGETHERAI_API_KEY'),
    base_url="https://api.together.xyz/v1"
)

openai_client = OpenAI(
    api_key=os.getenv('OPENAI_API_KEY')
)

anthropic_client = Anthropic(
    api_key=os.getenv('ANTHROPIC_API_KEY')
)

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Cap simultaneous API calls across all worker threads so a large run does not
# immediately trip provider rate limits.
MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

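# Dataset loading. @st.cache_data memoizes the parsed questions, so reruns
# triggered by widget interactions do not re-download the dataset.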
@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    """Load a dataset via Hugging Face datasets and normalize it to dicts.

    Assumes a MedMCQA-style schema (question, opa-opd, cop, choice_type,
    subject_name, topic_name, exp); only single-answer questions are kept.
    """
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        # 'cop' is the zero-based index of the correct option.
        correct_answer = options[row['cop']]

        questions.append({
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        })

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions

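# Transient API failures (rate limits, timeouts, malformed replies) are retried
# with exponential backoff. Exceptions propagate out of get_model_response so
# tenacity can see them; after the final attempt, retry_error_callback converts
# the exception into an ("Error: ...", detail) tuple instead of raising, so one
# bad call cannot crash a whole evaluation run.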
@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(5),
    retry=retry_if_exception_type(Exception),
    retry_error_callback=lambda retry_state: (
        f"Error: {retry_state.outcome.exception()}",
        str(retry_state.outcome.exception()),
    ),
)
def get_model_response(question, options, prompt_template, model_name):
    """Query one model with one question and parse its JSON reply.

    Returns an (answer, raw_response_text) tuple; answer is an "Error: ..."
    string when the reply cannot be parsed or does not match any option.
    """
    with semaphore:
        model_config = MODELS[model_name]
        # Render the options as "A. ...", "B. ...", etc.
        options_text = "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)

        provider = model_config["provider"]

        if provider == "togetherai":
            response = togetherai_client.chat.completions.create(
                model=model_config["model_id"],
                messages=[{"role": "user", "content": prompt}]
            )
            response_text = response.choices[0].message.content.strip()

        elif provider == "openai":
            response = openai_client.chat.completions.create(
                model=model_config["model_id"],
                messages=[{"role": "user", "content": prompt}]
            )
            response_text = response.choices[0].message.content.strip()

        elif provider == "anthropic":
            response = anthropic_client.messages.create(
                model=model_config["model_id"],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4096
            )
            response_text = response.content[0].text

        elif provider == "google":
            model = genai.GenerativeModel(model_name=model_config["model_id"])
            chat_session = model.start_chat(history=[])
            response_text = chat_session.send_message(prompt).text

        else:
            raise ValueError(f"Unknown provider for {model_name}: {provider}")

        # Models often wrap the JSON in prose or code fences; grab the
        # outermost braces rather than parsing the whole reply.
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: Invalid response format", response_text

        json_response = json.loads(json_match.group(0))
        answer = json_response.get('answer', '').strip()
        # Tolerate answers prefixed with their option letter, e.g. "B. ...".
        answer = re.sub(r'^[A-D]\.\s*', '', answer)

        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options", response_text

        return answer, response_text

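# A well-formed reply that the parser above accepts looks like (illustrative):
#   {"answer": "B. Vitamin B12", "choice reasoning": "...", "elimination reasoning": "..."}
# Only the "answer" field is used; a leading "B. " label is stripped before the
# answer is compared against the option texts.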
def evaluate_response(model_response, correct_answer):
    """Score a parsed answer: case-insensitive exact match; errors count as wrong."""
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()

def process_single_evaluation(question, prompt_template, model_name):
    """Run one (question, model) pair and package everything needed for display."""
    answer, response_text = get_model_response(
        question['question'],
        question['options'],
        prompt_template,
        model_name
    )
    is_correct = evaluate_response(answer, question['correct_answer'])
    return {
        'question': question['question'],
        'options': question['options'],
        'model_response': answer,
        'raw_llm_response': response_text,
        'prompt_sent': prompt_template,
        'correct_answer': question['correct_answer'],
        'subject': question['subject_name'],
        'is_correct': is_correct,
        'explanation': question['explanation'],
        'model_name': model_name
    }

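# The executor fans out one task per (model, question) pair. Its max_workers
# matches the semaphore inside get_model_response, so at most
# MAX_CONCURRENT_CALLS API requests are in flight at any time.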
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback):
    """Evaluate every (model, question) pair in a thread pool.

    Results arrive in completion order, not submission order; the caller
    regroups them by model afterwards.
    """
    results = []
    total_iterations = len(models_to_evaluate) * len(questions)
    current_iteration = 0

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
        future_to_params = {}
        for model_name in models_to_evaluate:
            for question in questions:
                future = executor.submit(process_single_evaluation, question, prompt_template, model_name)
                future_to_params[future] = (model_name, question)

        for future in as_completed(future_to_params):
            results.append(future.result())
            current_iteration += 1
            progress_callback(current_iteration, total_iterations)

    return results

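# Everything below is UI. Streamlit reruns this script top-to-bottom on every
# widget interaction, so anything that must survive a rerun lives in
# st.session_state.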
def main():
    st.set_page_config(page_title="LLM Benchmarking in Healthcare", layout="wide")

    if not check_password():
        st.stop()
    st.title("LLM Benchmarking in Healthcare")

    # Session state carries results and the detailed-view selection across reruns.
    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
    if 'detailed_model' not in st.session_state:
        st.session_state.detailed_model = None
    if 'detailed_dataset' not in st.session_state:
        st.session_state.detailed_dataset = None
    if 'last_evaluated_dataset' not in st.session_state:
        st.session_state.last_evaluated_dataset = None

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        # st.multiselect returns a list, hence the plural name.
        selected_models = st.multiselect(
            "Select Model(s)",
            options=list(MODELS.keys()),
            default=[list(MODELS.keys())[0]],
            help="Choose one or more models to evaluate."
        )

    models_to_evaluate = selected_models

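    # Note: {question} and {options} are substituted with str.replace, not
    # str.format, so the literal JSON braces in the example below are safe.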
    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here (e.g., A. xxx, B. xxx, C. xxx)",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The prompt is editable; adjust it before starting a run."
        )

    with col2:
        st.markdown("""
        ### Prompt Variables
        - `{question}`: The medical question
        - `{options}`: The multiple choice options
        """)

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)
        subjects = sorted(set(q['subject_name'] for q in questions))
        selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

        if selected_subject != "All":
            questions = [q for q in questions if q['subject_name'] == selected_subject]

        num_questions = st.number_input("Number of questions to evaluate", min_value=1, max_value=len(questions))

    if st.button("Start Evaluation"):
        with st.spinner("Starting evaluation..."):
            selected_questions = questions[:num_questions]

            progress_container = st.container()
            progress_bar = progress_container.progress(0)
            status_text = progress_container.empty()

            def update_progress(current, total):
                progress_bar.progress(current / total)
                status_text.text(f"Progress: {current}/{total} evaluations completed")

            results = process_evaluations_concurrently(
                selected_questions,
                prompt_template,
                models_to_evaluate,
                update_progress
            )

            # Regroup the flat, completion-ordered result list by model.
            all_results = {}
            for result in results:
                model = result.pop('model_name')
                all_results.setdefault(model, []).append(result)

            st.session_state.all_results = all_results
            st.session_state.last_evaluated_dataset = selected_dataset

            # Default the detailed view to the first evaluated model/dataset.
            if st.session_state.detailed_model is None and all_results:
                st.session_state.detailed_model = list(all_results.keys())[0]
            if st.session_state.detailed_dataset is None:
                st.session_state.detailed_dataset = selected_dataset

            st.success("Evaluation completed!")
            st.rerun()

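    # Results live in st.session_state, so the summary below re-renders on
    # every interaction without re-running any evaluations.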
    if st.session_state.all_results:
        st.subheader("Evaluation Results")

        model_metrics = {}
        for model_name, results in st.session_state.all_results.items():
            df = pd.DataFrame(results)
            model_metrics[model_name] = {'Accuracy': df['is_correct'].mean()}

        metrics_df = pd.DataFrame(model_metrics).T

        st.subheader("Model Performance Comparison")

        # Melt to long form (index, variable, value) for Altair.
        accuracy_chart = alt.Chart(
            metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
        ).mark_bar().encode(
            x=alt.X('index:N', title=None, axis=None),
            y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
            color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
            tooltip=['index:N', 'value:Q']
        ).properties(
            height=300,
            title={
                "text": "Model Accuracy",
                "baseline": "bottom",
                "orient": "bottom",
                "dy": 20
            }
        )

        st.altair_chart(accuracy_chart, use_container_width=True)

    if st.session_state.all_results:
        st.subheader("Detailed Results")

        # on_change callbacks persist the picks so they survive reruns.
        def update_model():
            st.session_state.detailed_model = st.session_state.model_select

        def update_dataset():
            st.session_state.detailed_dataset = st.session_state.dataset_select

        col1, col2 = st.columns(2)
        with col1:
            selected_model_details = st.selectbox(
                "Select model",
                options=list(st.session_state.all_results.keys()),
                key="model_select",
                on_change=update_model,
                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
                if st.session_state.detailed_model in st.session_state.all_results else 0
            )

        with col2:
            # Only the most recently evaluated dataset has results to show.
            selected_dataset_details = st.selectbox(
                "Select dataset",
                options=[st.session_state.last_evaluated_dataset],
                key="dataset_select",
                on_change=update_dataset
            )

        if selected_model_details in st.session_state.all_results:
            results = st.session_state.all_results[selected_model_details]
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()

            st.metric("Accuracy", f"{accuracy:.2%}")

            # One expander per question: prompt, raw reply, and verdict.
            for idx, result in enumerate(results):
                with st.expander(f"Question {idx + 1} - {result['subject']}"):
                    st.write("**Question:**", result['question'])
                    st.write("**Options:**")
                    for i, opt in enumerate(result['options']):
                        st.write(f"{chr(65 + i)}. {opt}")

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Prompt Used:**")
                        st.code(result['prompt_sent'])
                    with col2:
                        st.write("**Raw Response:**")
                        st.code(result['raw_llm_response'])

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Correct Answer:**", result['correct_answer'])
                        st.write("**Model Answer:**", result['model_response'])
                    with col2:
                        if result['is_correct']:
                            st.success("Correct!")
                        else:
                            st.error("Incorrect")

                    st.write("**Explanation:**", result['explanation'])
        else:
            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

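        # Flatten every model's results into one long table for CSV export.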
        st.markdown("---")
        all_data = []

        for model_name, results in st.session_state.all_results.items():
            for result in results:
                all_data.append({
                    'dataset': st.session_state.last_evaluated_dataset,
                    'model': model_name,
                    'question': result['question'],
                    'correct_answer': result['correct_answer'],
                    'subject': result['subject'],
                    'options': ' | '.join(result['options']),
                    'model_response': result['model_response'],
                    'is_correct': result['is_correct'],
                    'explanation': result['explanation']
                })

        complete_df = pd.DataFrame(all_data)
        csv = complete_df.to_csv(index=False)

        st.download_button(
            label="Download All Results as CSV",
            data=csv,
            file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
            mime="text/csv",
            key="download_all_results"
        )

if __name__ == "__main__":
    main()