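"""Streamlit app for benchmarking LLMs on multiple-choice healthcare QA datasets."""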
import streamlit as st
import pandas as pd
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
from openai import OpenAI
import os
from config import DATASETS, MODELS
import matplotlib.pyplot as plt
import altair as alt
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
import threading
from anthropic import Anthropic
import google.generativeai as genai
import hmac
import hashlib
load_dotenv()
def initialize_session_state():
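    """Seed Streamlit session state with defaults for the API clients and the configured flag."""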
if 'api_configured' not in st.session_state:
st.session_state.api_configured = False
if 'togetherai_client' not in st.session_state:
st.session_state.togetherai_client = None
if 'openai_client' not in st.session_state:
st.session_state.openai_client = None
if 'anthropic_client' not in st.session_state:
st.session_state.anthropic_client = None
def setup_api_clients():
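    """Render the sidebar API configuration UI and build the provider clients."""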
initialize_session_state()
with st.sidebar:
st.title("API Configuration")
use_stored = st.checkbox("Use the stored API keys")
if use_stored:
username = st.text_input("Username")
password = st.text_input("Password", type="password")
if st.button("Verify Credentials"):
if (hmac.compare_digest(username, os.environ.get("STREAMLIT_USERNAME", "")) and
hmac.compare_digest(password, os.environ.get("STREAMLIT_PASSWORD", ""))):
st.session_state.togetherai_client = OpenAI(
api_key=os.getenv('TOGETHERAI_API_KEY'),
base_url="https://api.together.xyz/v1"
)
st.session_state.openai_client = OpenAI(
api_key=os.getenv('OPENAI_API_KEY')
)
st.session_state.anthropic_client = Anthropic(
api_key=os.getenv('ANTHROPIC_API_KEY')
)
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
st.session_state.api_configured = True
st.success("Successfully configured the API clients with stored keys!")
else:
st.error("Invalid credentials. Please try again or use your own API keys.")
st.session_state.api_configured = False
else:
st.subheader("Enter Your API Keys")
togetherai_key = st.text_input("Together AI API Key", type="password", key="togetherai_key")
openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
anthropic_key = st.text_input("Anthropic API Key", type="password", key="anthropic_key")
gemini_key = st.text_input("Gemini API Key", type="password", key="gemini_key")
if st.button("Initialize with the provided keys"):
try:
st.session_state.togetherai_client = OpenAI(
api_key=togetherai_key,
base_url="https://api.together.xyz/v1"
)
st.session_state.openai_client = OpenAI(
api_key=openai_key
)
st.session_state.anthropic_client = Anthropic(
api_key=anthropic_key
)
genai.configure(api_key=gemini_key)
st.session_state.api_configured = True
st.success("Successfully configured the API clients with provided keys!")
except Exception as e:
st.error(f"Error initializing API clients: {str(e)}")
st.session_state.api_configured = False
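# Cap how many model API calls may run at once across the worker threads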
MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
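    """Load the configured split, keep single-answer questions, and return them as normalized dicts."""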
dataset_config = DATASETS[dataset_name]
dataset = load_dataset(dataset_config["loader"])
df = pd.DataFrame(dataset[split])
df = df[df['choice_type'] == 'single']
questions = []
for _, row in df.iterrows():
options = [row['opa'], row['opb'], row['opc'], row['opd']]
correct_answer = options[row['cop']]
question_dict = {
'question': row['question'],
'options': options,
'correct_answer': correct_answer,
'subject_name': row['subject_name'],
'topic_name': row['topic_name'],
'explanation': row['exp']
}
questions.append(question_dict)
st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
return questions
@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(5),
retry=retry_if_exception_type(Exception)
)
def get_model_response(question, options, prompt_template, model_name, clients):
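    """Fill the prompt template, query the selected provider, and parse the JSON answer from the reply."""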
with semaphore:
try:
model_config = MODELS[model_name]
options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
provider = model_config["provider"]
if provider == "togetherai":
response = clients["togetherai"].chat.completions.create(
model=model_config["model_id"],
messages=[{"role": "user", "content": prompt}]
)
response_text = response.choices[0].message.content.strip()
elif provider == "openai":
response = clients["openai"].chat.completions.create(
model=model_config["model_id"],
messages=[{"role": "user", "content": prompt}]
)
response_text = response.choices[0].message.content.strip()
elif provider == "anthropic":
response = clients["anthropic"].messages.create(
model=model_config["model_id"],
messages=[{"role": "user", "content": prompt}],
max_tokens=4096
)
response_text = response.content[0].text
elif provider == "google":
model = genai.GenerativeModel(
model_name=model_config["model_id"]
)
chat_session = model.start_chat(
history=[]
)
                response_text = chat_session.send_message(prompt).text
            else:
                # Guard so response_text is always defined even for an unrecognized provider
                return f"Error: Unsupported provider '{provider}'", ""
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if not json_match:
                return "Error: Invalid response format", response_text
json_response = json.loads(json_match.group(0))
answer = json_response.get('answer', '').strip()
answer = re.sub(r'^[A-D]\.\s*', '', answer)
if not any(answer.lower() == opt.lower() for opt in options):
return f"Error: Answer '{answer}' does not match any options", response_text
return answer, response_text
except Exception as e:
return f"Error: {str(e)}", str(e)
def evaluate_response(model_response, correct_answer):
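    """Return True when the parsed model answer matches the correct option (case-insensitive)."""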
if model_response.startswith("Error:"):
return False
is_correct = model_response.lower().strip() == correct_answer.lower().strip()
return is_correct
def process_single_evaluation(question, prompt_template, model_name, clients):
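    """Evaluate a single question with a single model and package the outcome for reporting."""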
answer, response_text = get_model_response(
question['question'],
question['options'],
prompt_template,
model_name,
clients
)
is_correct = evaluate_response(answer, question['correct_answer'])
return {
'question': question['question'],
'options': question['options'],
'model_response': answer,
'raw_llm_response': response_text,
'prompt_sent': prompt_template,
'correct_answer': question['correct_answer'],
'subject': question['subject_name'],
'is_correct': is_correct,
'explanation': question['explanation'],
'model_name': model_name
}
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients):
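    """Submit every (model, question) pair to a thread pool and collect results as they complete."""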
results = []
total_iterations = len(models_to_evaluate) * len(questions)
current_iteration = 0
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
future_to_params = {}
for model_name in models_to_evaluate:
for question in questions:
future = executor.submit(process_single_evaluation, question, prompt_template, model_name, clients)
future_to_params[future] = (model_name, question)
for future in as_completed(future_to_params):
result = future.result()
results.append(result)
current_iteration += 1
progress_callback(current_iteration, total_iterations)
return results
def main():
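    """Streamlit entry point: API setup, dataset and model selection, evaluation run, and results display."""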
st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
initialize_session_state()
setup_api_clients()
if not st.session_state.api_configured:
st.warning("Please configure API keys in the sidebar to proceed")
st.stop()
if 'all_results' not in st.session_state:
st.session_state.all_results = {}
if 'detailed_model' not in st.session_state:
st.session_state.detailed_model = None
if 'detailed_dataset' not in st.session_state:
st.session_state.detailed_dataset = None
if 'last_evaluated_dataset' not in st.session_state:
st.session_state.last_evaluated_dataset = None
col1, col2 = st.columns(2)
with col1:
selected_dataset = st.selectbox(
"Select Dataset",
options=list(DATASETS.keys()),
help="Choose the dataset to evaluate on"
)
with col2:
selected_model = st.multiselect(
"Select Model(s)",
options=list(MODELS.keys()),
default=[list(MODELS.keys())[0]],
help="Choose one or more models to evaluate."
)
models_to_evaluate = selected_model
default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}
Options:
{options}
## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options
Example response format:
{
"answer": "exact option text here (e.g., A. xxx, B. xxx, C. xxx)",
"choice reasoning": "your detailed reasoning here",
"elimination reasoning": "why you ruled out other options"
}
Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''
col1, col2 = st.columns([2, 1])
with col1:
prompt_template = st.text_area(
"Customize Prompt Template",
default_prompt,
height=400,
            help="The prompt below is editable. Feel free to adjust it before running the evaluation."
)
with col2:
st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")
with st.spinner("Loading dataset..."):
questions = load_dataset_by_name(selected_dataset)
subjects = sorted(list(set(q['subject_name'] for q in questions)))
selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
if selected_subject != "All":
questions = [q for q in questions if q['subject_name'] == selected_subject]
num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
if st.button("Start Evaluation"):
with st.spinner("Starting evaluation..."):
selected_questions = questions[:num_questions]
# Create a clients dictionary
clients = {
"togetherai": st.session_state["togetherai_client"],
"openai": st.session_state["openai_client"],
"anthropic": st.session_state["anthropic_client"]
}
progress_container = st.container()
progress_bar = progress_container.progress(0)
status_text = progress_container.empty()
def update_progress(current, total):
progress = current / total
progress_bar.progress(progress)
status_text.text(f"Progress: {current}/{total} evaluations completed")
results = process_evaluations_concurrently(
selected_questions,
prompt_template,
models_to_evaluate,
update_progress,
clients
)
all_results = {}
for result in results:
model = result.pop('model_name')
if model not in all_results:
all_results[model] = []
all_results[model].append(result)
st.session_state.all_results = all_results
st.session_state.last_evaluated_dataset = selected_dataset
if st.session_state.detailed_model is None and all_results:
st.session_state.detailed_model = list(all_results.keys())[0]
if st.session_state.detailed_dataset is None:
st.session_state.detailed_dataset = selected_dataset
st.success("Evaluation completed!")
st.rerun()
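    # Summary view: per-model accuracy and the comparison chart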
if st.session_state.all_results:
st.subheader("Evaluation Results")
model_metrics = {}
for model_name, results in st.session_state.all_results.items():
df = pd.DataFrame(results)
metrics = {
'Accuracy': df['is_correct'].mean(),
}
model_metrics[model_name] = metrics
metrics_df = pd.DataFrame(model_metrics).T
st.subheader("Model Performance Comparison")
accuracy_chart = alt.Chart(
metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
).mark_bar().encode(
x=alt.X('index:N', title=None, axis=None),
y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
tooltip=['index:N', 'value:Q']
).properties(
height=300,
title={
"text": "Model Accuracy",
"baseline": "bottom",
"orient": "bottom",
"dy": 20
}
)
st.altair_chart(accuracy_chart, use_container_width=True)
if st.session_state.all_results:
st.subheader("Detailed Results")
def update_model():
st.session_state.detailed_model = st.session_state.model_select
def update_dataset():
st.session_state.detailed_dataset = st.session_state.dataset_select
col1, col2 = st.columns(2)
with col1:
selected_model_details = st.selectbox(
"Select model",
options=list(st.session_state.all_results.keys()),
key="model_select",
on_change=update_model,
index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
if st.session_state.detailed_model in st.session_state.all_results else 0
)
with col2:
selected_dataset_details = st.selectbox(
"Select dataset",
options=[st.session_state.last_evaluated_dataset],
key="dataset_select",
on_change=update_dataset
)
if selected_model_details in st.session_state.all_results:
results = st.session_state.all_results[selected_model_details]
df = pd.DataFrame(results)
accuracy = df['is_correct'].mean()
st.metric("Accuracy", f"{accuracy:.2%}")
for idx, result in enumerate(results):
with st.expander(f"Question {idx + 1} - {result['subject']}"):
st.write("**Question:**", result['question'])
st.write("**Options:**")
for i, opt in enumerate(result['options']):
st.write(f"{chr(65+i)}. {opt}")
col1, col2 = st.columns(2)
with col1:
st.write("**Prompt Used:**")
st.code(result['prompt_sent'])
with col2:
st.write("**Raw Response:**")
st.code(result['raw_llm_response'])
col1, col2 = st.columns(2)
with col1:
st.write("**Correct Answer:**", result['correct_answer'])
st.write("**Model Answer:**", result['model_response'])
with col2:
if result['is_correct']:
st.success("Correct!")
else:
st.error("Incorrect")
st.write("**Explanation:**", result['explanation'])
else:
st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
st.markdown("---")
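        # Flatten every model's results into one table for the CSV download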
all_data = []
for model_name, results in st.session_state.all_results.items():
for question_idx, result in enumerate(results):
row = {
'dataset': st.session_state.last_evaluated_dataset,
'model': model_name,
'question': result['question'],
'correct_answer': result['correct_answer'],
'subject': result['subject'],
'options': ' | '.join(result['options']),
'model_response': result['model_response'],
'is_correct': result['is_correct'],
'explanation': result['explanation']
}
all_data.append(row)
complete_df = pd.DataFrame(all_data)
csv = complete_df.to_csv(index=False)
st.download_button(
label="Download All Results as CSV",
data=csv,
file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
mime="text/csv",
key="download_all_results"
)
if __name__ == "__main__":
main()