# Source: Hugging Face file view (raw / history / blame) of app.py
# Author: ehagey -- commit "Update app.py" (ceec99c, verified)
# File size: 23.1 kB
import streamlit as st
import pandas as pd
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
from openai import OpenAI
import os
from config import DATASETS, MODELS
import matplotlib.pyplot as plt
import altair as alt
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
import threading
from anthropic import Anthropic
import google.generativeai as genai
import hmac
import hashlib
from uuid import uuid4
from datetime import datetime
from huggingface_hub import CommitScheduler, Repository
from pathlib import Path
# Load variables from a local .env file (if present) into the environment.
load_dotenv()
st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")

# Guards the read-modify-write cycle on the results CSV.
WRITE_LOCK = threading.Lock()

# Local folder that holds the results file and is mirrored to the Hub.
DATA_DIR = Path("data")
DATA_DIR.mkdir(exist_ok=True)
RESULTS_FILE = DATA_DIR / "results.csv"


@st.cache_resource
def _get_commit_scheduler():
    """Create the Hub commit scheduler once and reuse it afterwards.

    Streamlit re-executes this script on every rerun. Per the
    huggingface_hub documentation a CommitScheduler starts its background
    thread when it is instantiated, so building one at module level would
    add a new thread on each rerun. Caching the instance avoids that.
    """
    return CommitScheduler(
        repo_id=os.getenv("HF_REPO_ID"),
        repo_type="dataset",
        folder_path=DATA_DIR,
        path_in_repo="data",
        every=10,  # commit interval; minutes per the huggingface_hub docs
        token=os.getenv("HF_TOKEN"),
    )


scheduler = _get_commit_scheduler()
def initialize_session_state():
    """Give every session-state key used by this app a default value.

    Keys that are already present in ``st.session_state`` are left untouched.
    """
    defaults = {
        'api_configured': False,
        'togetherai_client': None,
        'openai_client': None,
        'anthropic_client': None,
        'all_results': {},
        'detailed_model': None,
        'detailed_dataset': None,
        'last_evaluated_dataset': None,
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            setattr(st.session_state, key, default)


initialize_session_state()
def _credentials_match(username, password, stored_username, stored_password):
    """Return True only if both values equal their non-empty stored counterparts.

    An unset (empty) stored username or password never matches, so a blank
    login form cannot unlock the stored API keys. Values are compared as
    UTF-8 bytes because ``hmac.compare_digest`` raises ``TypeError`` for
    ``str`` arguments that contain non-ASCII characters.
    """
    if not stored_username or not stored_password:
        return False
    username_ok = hmac.compare_digest(
        username.encode("utf-8"), stored_username.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), stored_password.encode("utf-8")
    )
    return username_ok and password_ok


def _configure_clients(togetherai_key, openai_key, anthropic_key, gemini_key, source):
    """Build every provider client, store them in session state and report the outcome.

    ``source`` is the word inserted into the success message
    ("stored" or "provided").
    """
    try:
        st.session_state.togetherai_client = OpenAI(
            api_key=togetherai_key,
            base_url="https://api.together.xyz/v1"
        )
        st.session_state.openai_client = OpenAI(api_key=openai_key)
        st.session_state.anthropic_client = Anthropic(api_key=anthropic_key)
        genai.configure(api_key=gemini_key)
        st.session_state.api_configured = True
        st.success(f"Successfully configured the API clients with {source} keys!")
    except Exception as e:
        st.error(f"Error initializing API clients: {str(e)}")
        st.session_state.api_configured = False


def setup_api_clients():
    """Render the sidebar form that configures the provider API clients.

    The user either unlocks the keys stored in environment variables with a
    username/password, or types their own keys. On success the clients are
    placed in ``st.session_state`` and ``api_configured`` is set to True;
    on any failure ``api_configured`` is set to False.
    """
    with st.sidebar:
        st.title("API Configuration")
        use_stored = st.checkbox("Use the stored API keys")
        if use_stored:
            username = st.text_input("Username")
            password = st.text_input("Password", type="password")
            if st.button("Verify Credentials"):
                stored_username = os.getenv("STREAMLIT_USERNAME", "")
                stored_password = os.getenv("STREAMLIT_PASSWORD", "")
                if _credentials_match(username, password, stored_username, stored_password):
                    _configure_clients(
                        os.getenv('TOGETHERAI_API_KEY'),
                        os.getenv('OPENAI_API_KEY'),
                        os.getenv('ANTHROPIC_API_KEY'),
                        os.getenv("GEMINI_API_KEY"),
                        source="stored",
                    )
                else:
                    st.error("Invalid credentials. Please try again or use your own API keys.")
                    st.session_state.api_configured = False
        else:
            st.subheader("Enter Your API Keys")
            togetherai_key = st.text_input("Together AI API Key", type="password", key="togetherai_key")
            openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
            anthropic_key = st.text_input("Anthropic API Key", type="password", key="anthropic_key")
            gemini_key = st.text_input("Gemini API Key", type="password", key="gemini_key")
            if st.button("Initialize with the provided keys"):
                _configure_clients(
                    togetherai_key,
                    openai_key,
                    anthropic_key,
                    gemini_key,
                    source="provided",
                )
# Draw the sidebar API-key form on every script run.
setup_api_clients()

# NOTE: there is deliberately no ``scheduler.start()`` call here. Per the
# huggingface_hub documentation, CommitScheduler begins running as soon as it
# is instantiated and exposes no ``start`` method, so calling it raises
# AttributeError.

# Upper bound on simultaneous provider requests; used for both the semaphore
# below and the thread-pool size.
MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    """Load one configured dataset and return its single-answer questions.

    The loader id is looked up in ``DATASETS[dataset_name]["loader"]``. Only
    rows whose ``choice_type`` equals ``'single'`` are kept. Each returned
    dict holds the question text, the four options (``opa``..``opd``), the
    correct option text (selected by using ``cop`` as an index into the
    options), the subject, the topic and the explanation.
    """
    loader_id = DATASETS[dataset_name]["loader"]
    frame = pd.DataFrame(load_dataset(loader_id)[split])
    single_choice = frame[frame['choice_type'] == 'single']

    def to_question(row):
        choices = [row['opa'], row['opb'], row['opc'], row['opd']]
        return {
            'question': row['question'],
            'options': choices,
            'correct_answer': choices[row['cop']],
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }

    questions = [to_question(row) for _, row in single_choice.iterrows()]
    st.write(f"Loaded {len(questions)} single-select questions from `{dataset_name}`")
    return questions
# Providers that _request_completion knows how to call.
_SUPPORTED_PROVIDERS = ("togetherai", "openai", "anthropic", "google")


def _request_completion(prompt, model_config, clients):
    """Send ``prompt`` to the model's provider and return the raw reply text.

    The semaphore is held only for the duration of a single request, so a
    request that is waiting to be retried does not occupy a slot.
    """
    provider = model_config["provider"]
    with semaphore:
        if provider in ("togetherai", "openai"):
            # Both providers are reached through an OpenAI-style client.
            response = clients[provider].chat.completions.create(
                model=model_config["model_id"],
                messages=[{"role": "user", "content": prompt}]
            )
            return response.choices[0].message.content.strip()
        if provider == "anthropic":
            response = clients["anthropic"].messages.create(
                model=model_config["model_id"],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4096
            )
            return response.content[0].text
        if provider == "google":
            model = genai.GenerativeModel(model_name=model_config["model_id"])
            chat_session = model.start_chat(history=[])
            return chat_session.send_message(prompt).text
    raise ValueError(f"Unsupported provider '{provider}'")


def _parse_answer(response_text, options):
    """Extract the chosen option from a model reply.

    Returns the answer text (with any leading "A. " style prefix removed)
    when it matches one of ``options`` case-insensitively, otherwise a
    string starting with "Error:" that describes the problem.
    """
    # Greedy match: from the first "{" to the last "}" in the reply.
    json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
    if not json_match:
        return "Error: Invalid response format"
    try:
        json_response = json.loads(json_match.group(0))
    except json.JSONDecodeError as e:
        return f"Error: {str(e)}"
    answer = json_response.get('answer', '')
    if not isinstance(answer, str):
        return "Error: Invalid response format"
    answer = re.sub(r'^[A-D]\.\s*', '', answer.strip())
    if not any(answer.lower() == opt.lower() for opt in options):
        return f"Error: Answer '{answer}' does not match any options"
    return answer


def get_model_response(question, options, prompt_template, model_name, clients):
    """Ask one model one multiple-choice question.

    Args:
        question: Question text substituted for ``{question}`` in the template.
        options: Option texts; rendered as "A. ...", "B. ..." lines and
            substituted for ``{options}``.
        prompt_template: Prompt containing the two placeholders above.
        model_name: Key into ``MODELS``.
        clients: Mapping with the "togetherai", "openai" and "anthropic" clients.

    Returns:
        ``(answer, response_text)``. ``answer`` is the matched option text or
        a string starting with "Error:"; this function never raises.

    The provider request is retried with exponential backoff (up to five
    attempts). Previously the retry decorator sat on this function, whose
    body caught every exception, so no retry could ever be triggered.
    """
    try:
        model_config = MODELS[model_name]
        provider = model_config["provider"]
        if provider not in _SUPPORTED_PROVIDERS:
            # Fail fast: retrying cannot fix a configuration error.
            message = f"Unsupported provider '{provider}'"
            return f"Error: {message}", message
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
        # reraise=True surfaces the last underlying exception (rather than
        # tenacity's RetryError) so the error text below stays meaningful.
        request_with_retry = retry(
            wait=wait_exponential(multiplier=1, min=4, max=10),
            stop=stop_after_attempt(5),
            retry=retry_if_exception_type(Exception),
            reraise=True,
        )(_request_completion)
        response_text = request_with_retry(prompt, model_config, clients)
    except Exception as e:
        return f"Error: {str(e)}", str(e)
    return _parse_answer(response_text, options), response_text
def evaluate_response(model_response, correct_answer):
    """Return True when the model's answer equals the correct answer.

    Responses that start with "Error:" are always wrong. Otherwise the two
    strings are compared after lower-casing and stripping whitespace.
    """
    return (
        not model_response.startswith("Error:")
        and model_response.lower().strip() == correct_answer.lower().strip()
    )
def _append_result_to_csv(results_file, result):
    """Add one result row to the CSV at ``results_file``, creating the file if needed."""
    new_row = pd.DataFrame([result])
    if results_file.exists():
        existing_df = pd.read_csv(results_file)
        # DataFrame.append was removed in pandas 2.0; concat is its replacement.
        updated_df = pd.concat([existing_df, new_row], ignore_index=True)
    else:
        updated_df = new_row
    updated_df.to_csv(results_file, index=False)


def process_single_evaluation(question, prompt_template, model_name, clients, last_evaluated_dataset):
    """Evaluate one question with one model and persist the outcome.

    Args:
        question: Question dict with 'question', 'options', 'correct_answer',
            'subject_name' and 'explanation' keys.
        prompt_template: Prompt template passed on to ``get_model_response``.
        model_name: Name of the model to query.
        clients: Provider clients passed on to ``get_model_response``.
        last_evaluated_dataset: Value stored in the row's 'dataset' column.

    Returns:
        The result row as a dict; the same row is appended to ``RESULTS_FILE``.
    """
    answer, _ = get_model_response(
        question['question'],
        question['options'],
        prompt_template,
        model_name,
        clients
    )
    is_correct = evaluate_response(answer, question['correct_answer'])
    result = {
        'dataset': last_evaluated_dataset,
        'model': model_name,
        'question': question['question'],
        'correct_answer': question['correct_answer'],
        'subject': question['subject_name'],
        'options': ' | '.join(question['options']),
        'model_response': answer,
        'is_correct': is_correct,
        'explanation': question['explanation'],
        'timestamp': datetime.utcnow().isoformat()
    }
    # The file is rewritten as a whole, so the read-modify-write must be atomic
    # with respect to the other worker threads.
    with WRITE_LOCK:
        _append_result_to_csv(RESULTS_FILE, result)
    return result
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients, last_evaluated_dataset):
    """Run every (model, question) evaluation on a thread pool.

    Pairs whose model name and question text already appear in
    ``RESULTS_FILE`` are not re-run; they only advance the progress counter.
    ``progress_callback(done, total)`` is called once per skipped pair and
    once per finished evaluation. Returns the list of newly produced result
    dicts in completion order.
    """
    total = len(models_to_evaluate) * len(questions)
    done = 0

    already_done = set()
    if RESULTS_FILE.exists():
        saved = pd.read_csv(RESULTS_FILE)
        already_done = set(zip(saved['model'], saved['question']))

    collected = []
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as pool:
        pending = []
        for model_name in models_to_evaluate:
            for question in questions:
                if (model_name, question['question']) not in already_done:
                    pending.append(
                        pool.submit(
                            process_single_evaluation,
                            question,
                            prompt_template,
                            model_name,
                            clients,
                            last_evaluated_dataset
                        )
                    )
                else:
                    # Already persisted: count it as finished without re-running.
                    done += 1
                    progress_callback(done, total)
        for finished in as_completed(pending):
            collected.append(finished.result())
            done += 1
            progress_callback(done, total)
    return collected
def main():
    """Render the benchmarking page.

    Steps, in order: load persisted results from ``RESULTS_FILE``, offer a
    reset button, let the user pick a dataset, models and prompt template,
    run the evaluation on demand, then show the accuracy chart, per-question
    details and a CSV download button.
    """
    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
        st.session_state.last_evaluated_dataset = None

    # Rebuild the per-model result lists from the persisted CSV.
    if RESULTS_FILE.exists():
        existing_df = pd.read_csv(RESULTS_FILE)
        all_results = {}
        for _, row in existing_df.iterrows():
            all_results.setdefault(row['model'], []).append(row.to_dict())
        st.session_state.all_results = all_results
        st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
        st.info(f"Loaded existing results from `{RESULTS_FILE}`.")
    else:
        st.session_state.all_results = {}
        st.session_state.last_evaluated_dataset = None
        st.info("No existing results found. Ready to start fresh.")

    with st.sidebar:
        if st.button("Reset Results"):
            if RESULTS_FILE.exists():
                try:
                    RESULTS_FILE.unlink()
                    st.session_state.all_results = {}
                    st.session_state.last_evaluated_dataset = None
                    st.success("Results have been reset.")
                except Exception as e:
                    st.error(f"Error deleting file: {str(e)}")
            else:
                st.info("No results to reset.")

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_models = st.multiselect(
            "Select Model(s)",
            options=list(MODELS.keys()),
            default=[list(MODELS.keys())[0]],
            help="Choose one or more models to evaluate."
        )
    models_to_evaluate = selected_models

    # The prompt text is kept flush-left because it is sent to the model verbatim.
    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}
Options:
{options}
## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options
Example response format:
{
"answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
"choice reasoning": "your detailed reasoning here",
"elimination reasoning": "why you ruled out other options"
}
Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    # Customize Prompt Template
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="Edit the prompt template before starting the evaluation."
        )
    with col2:
        st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")

    # Load Dataset
    if st.session_state.api_configured:
        with st.spinner("Loading dataset..."):
            questions = load_dataset_by_name(selected_dataset)
    else:
        st.warning("Please configure the API keys in the sidebar to load datasets and proceed.")
        questions = []

    # Everything below needs at least one question (num_questions is only
    # defined in this branch), so the Start button lives inside it as well.
    if questions:
        # Filter by Subject
        subjects = sorted(list(set(q['subject_name'] for q in questions)))
        selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
        if selected_subject != "All":
            questions = [q for q in questions if q['subject_name'] == selected_subject]

        # Number of Questions to Evaluate
        num_questions = st.number_input(
            "Number of questions to evaluate",
            min_value=1,
            max_value=len(questions),
            value=min(10, len(questions)),
            step=1
        )

        # Start Evaluation Button
        if st.button("Start Evaluation"):
            if not models_to_evaluate:
                st.error("Please select at least one model to evaluate.")
            else:
                with st.spinner("Starting evaluation..."):
                    selected_questions = questions[:num_questions]
                    clients = {
                        "togetherai": st.session_state["togetherai_client"],
                        "openai": st.session_state["openai_client"],
                        "anthropic": st.session_state["anthropic_client"]
                    }
                    progress_container = st.container()
                    progress_bar = progress_container.progress(0)
                    status_text = progress_container.empty()

                    def update_progress(current, total):
                        progress = current / total
                        progress_bar.progress(progress)
                        status_text.text(f"Progress: {current}/{total} evaluations completed")

                    # Tag new rows with the dataset being evaluated now. The
                    # previous code passed the *previously* evaluated dataset
                    # whenever one existed, mislabelling the new rows.
                    results = process_evaluations_concurrently(
                        selected_questions,
                        prompt_template,
                        models_to_evaluate,
                        update_progress,
                        clients,
                        selected_dataset
                    )

                # Update Session State with New Results
                all_results = st.session_state.all_results.copy()
                for result in results:
                    model = result.pop('model')
                    if model not in all_results:
                        all_results[model] = []
                    all_results[model].append(result)
                st.session_state.all_results = all_results
                st.session_state.last_evaluated_dataset = selected_dataset

                # Set Default Detailed Model and Dataset if Not Set
                if st.session_state.detailed_model is None and all_results:
                    st.session_state.detailed_model = list(all_results.keys())[0]
                if st.session_state.detailed_dataset is None:
                    st.session_state.detailed_dataset = selected_dataset

                st.success("Evaluation completed!")
                # st.experimental_rerun is gone from current Streamlit releases;
                # prefer st.rerun and fall back only on older versions.
                rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
                rerun()

    # Display Evaluation Results
    if st.session_state.all_results:
        st.subheader("Evaluation Results")
        model_metrics = {}
        for model_name, results in st.session_state.all_results.items():
            df = pd.DataFrame(results)
            metrics = {
                'Accuracy': df['is_correct'].mean(),
            }
            model_metrics[model_name] = metrics
        metrics_df = pd.DataFrame(model_metrics).T.reset_index().rename(columns={'index': 'Model'})
        st.subheader("Model Performance Comparison")
        accuracy_chart = alt.Chart(
            metrics_df
        ).mark_bar().encode(
            x=alt.X('Model:N', title=None),
            y=alt.Y('Accuracy:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
            color=alt.Color('Model:N', scale=alt.Scale(scheme='blues')),
            tooltip=['Model:N', 'Accuracy:Q']
        ).properties(
            height=300,
            title={
                "text": "Model Accuracy",
                "anchor": "middle",
                "fontSize": 20
            }
        ).interactive()
        st.altair_chart(accuracy_chart, use_container_width=True)

    # Display Detailed Results
    if st.session_state.all_results:
        st.subheader("Detailed Results")

        def update_model():
            st.session_state.detailed_model = st.session_state.model_select

        def update_dataset():
            st.session_state.detailed_dataset = st.session_state.dataset_select

        col1, col2 = st.columns(2)
        with col1:
            selected_model_details = st.selectbox(
                "Select model",
                options=list(st.session_state.all_results.keys()),
                key="model_select",
                on_change=update_model,
                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
                if st.session_state.detailed_model in st.session_state.all_results else 0
            )
        with col2:
            selected_dataset_details = st.selectbox(
                "Select dataset",
                options=[st.session_state.last_evaluated_dataset] if st.session_state.last_evaluated_dataset else [],
                key="dataset_select",
                on_change=update_dataset
            )

        if selected_model_details and selected_model_details in st.session_state.all_results:
            results = st.session_state.all_results[selected_model_details]
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")
            for idx, result in enumerate(results):
                with st.expander(f"Question {idx + 1} - {result['subject']}"):
                    st.write("**Question:**", result['question'])
                    st.write("**Options:**")
                    for i, opt in enumerate(result['options'].split(' | ')):
                        st.write(f"{chr(65+i)}. {opt}")
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Model Response:**")
                        st.code(result.get('model_response', "N/A"))
                    with col2:
                        st.write("**Explanation:**")
                        st.code(result.get('explanation', "N/A"))
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Correct Answer:**", result['correct_answer'])
                        st.write("**Model Answer:**", result['model_response'])
                    with col2:
                        if result['is_correct']:
                            st.success("Correct!")
                        else:
                            st.error("Incorrect")
                    st.write("**Timestamp:**", result['timestamp'])
        else:
            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

    st.markdown("---")
    st.subheader("Download Results")
    if RESULTS_FILE.exists():
        csv_data = RESULTS_FILE.read_text(encoding='utf-8')
        st.download_button(
            label="Download All Results as CSV",
            data=csv_data,
            file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
            mime="text/csv",
            key="download_all_results"
        )
    else:
        st.info("No data available to download.")


if __name__ == "__main__":
    main()