import streamlit as st
import pandas as pd
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
from openai import OpenAI
import os
from config import DATASETS, MODELS
import matplotlib.pyplot as plt
import altair as alt
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
import threading
from anthropic import Anthropic
import google.generativeai as genai
import hmac
import hashlib

load_dotenv()


def check_password():
    def password_entered():
        # Compare credentials in constant time to avoid timing side channels.
        if hmac.compare_digest(st.session_state["username"], os.environ.get("STREAMLIT_USERNAME", "")) and \
           hmac.compare_digest(st.session_state["password"], os.environ.get("STREAMLIT_PASSWORD", "")):
            st.session_state["password_correct"] = True
            del st.session_state["password"]
            del st.session_state["username"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.text_input("Username", key="username")
        st.text_input("Password", type="password", key="password")
        st.button("Login", on_click=password_entered)
        return False
    elif not st.session_state["password_correct"]:
        st.text_input("Username", key="username")
        st.text_input("Password", type="password", key="password")
        st.button("Login", on_click=password_entered)
        st.error("User not known or password incorrect")
        return False

    # Password correct.
    return True


togetherai_client = OpenAI(
    api_key=os.getenv('TOGETHERAI_API_KEY'),
    base_url="https://api.together.xyz/v1"
)
openai_client = OpenAI(
    api_key=os.getenv('OPENAI_API_KEY')
)
anthropic_client = Anthropic(
    api_key=os.getenv('ANTHROPIC_API_KEY')
)
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
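
# Illustrative sketch only: the shapes that config.py is assumed to expose, inferred
# from how DATASETS and MODELS are accessed in this module ("loader", "provider",
# "model_id"). The actual entry names, dataset ids, and model ids may differ.
#
# DATASETS = {
#     "MedMCQA": {"loader": "openlifescienceai/medmcqa"},  # example HF dataset id
# }
# MODELS = {
#     "GPT-4o": {"provider": "openai", "model_id": "gpt-4o"},
#     "Claude 3.5 Sonnet": {"provider": "anthropic", "model_id": "claude-3-5-sonnet-latest"},
# }
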

@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    # Keep only single-select multiple-choice questions.
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]
        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions


@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(5),
    retry=retry_if_exception_type(Exception)
)
def get_model_response(question, options, prompt_template, model_name):
    with semaphore:
        try:
            model_config = MODELS[model_name]
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
            prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
            provider = model_config["provider"]

            if provider == "togetherai":
                response = togetherai_client.chat.completions.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}]
                )
                response_text = response.choices[0].message.content.strip()
            elif provider == "openai":
                response = openai_client.chat.completions.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}]
                )
                response_text = response.choices[0].message.content.strip()
            elif provider == "anthropic":
                response = anthropic_client.messages.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096
                )
                response_text = response.content[0].text
            elif provider == "google":
                model = genai.GenerativeModel(
                    model_name=model_config["model_id"]
                )
                chat_session = model.start_chat(
                    history=[]
                )
                response_text = chat_session.send_message(prompt).text
            else:
                # Guard against a misconfigured provider; the outer except turns this
                # into an "Error: ..." result instead of an unbound-variable crash.
                raise ValueError(f"Unknown provider: {provider}")

            # Extract the first JSON object from the response text.
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if not json_match:
                return "Error: Invalid response format", response_text

            json_response = json.loads(json_match.group(0))
            answer = json_response.get('answer', '').strip()
            # Strip a leading "A. " / "B. " style prefix before matching against the options.
            answer = re.sub(r'^[A-D]\.\s*', '', answer)

            if not any(answer.lower() == opt.lower() for opt in options):
                return f"Error: Answer '{answer}' does not match any options", response_text

            return answer, response_text
        except Exception as e:
            return f"Error: {str(e)}", str(e)


def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    is_correct = model_response.lower().strip() == correct_answer.lower().strip()
    return is_correct


def process_single_evaluation(question, prompt_template, model_name):
    answer, response_text = get_model_response(
        question['question'],
        question['options'],
        prompt_template,
        model_name
    )
    is_correct = evaluate_response(answer, question['correct_answer'])
    return {
        'question': question['question'],
        'options': question['options'],
        'model_response': answer,
        'raw_llm_response': response_text,
        'prompt_sent': prompt_template,
        'correct_answer': question['correct_answer'],
        'subject': question['subject_name'],
        'is_correct': is_correct,
        'explanation': question['explanation'],
        'model_name': model_name
    }


def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback):
    results = []
    total_iterations = len(models_to_evaluate) * len(questions)
    current_iteration = 0

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
        future_to_params = {}
        for model_name in models_to_evaluate:
            for question in questions:
                future = executor.submit(process_single_evaluation, question, prompt_template, model_name)
                future_to_params[future] = (model_name, question)

        for future in as_completed(future_to_params):
            result = future.result()
            results.append(result)
            current_iteration += 1
            progress_callback(current_iteration, total_iterations)

    return results


def main():
    st.set_page_config(page_title="LLM Benchmarking in Healthcare", layout="wide")

    if not check_password():
        st.stop()

    st.title("LLM Benchmarking in Healthcare")

    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
    if 'detailed_model' not in st.session_state:
        st.session_state.detailed_model = None
    if 'detailed_dataset' not in st.session_state:
        st.session_state.detailed_dataset = None
    if 'last_evaluated_dataset' not in st.session_state:
        st.session_state.last_evaluated_dataset = None

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_model = st.multiselect(
            "Select Model(s)",
            options=list(MODELS.keys()),
            default=[list(MODELS.keys())[0]],
            help="Choose one or more models to evaluate."
        )
    models_to_evaluate = selected_model

    # Only {question} and {options} are substituted (via str.replace in get_model_response),
    # so the literal braces in the JSON example below are left untouched.
    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here (e.g., A. xxx, B. xxx, C. xxx)",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The prompt template is editable; adjust it before starting a run."
        )
    with col2:
        st.markdown("""
        ### Prompt Variables
        - `{question}`: The medical question
        - `{options}`: The multiple choice options
        """)

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    subjects = sorted(list(set(q['subject_name'] for q in questions)))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]

    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))

    if st.button("Start Evaluation"):
        with st.spinner("Starting evaluation..."):
            selected_questions = questions[:num_questions]

            progress_container = st.container()
            progress_bar = progress_container.progress(0)
            status_text = progress_container.empty()

            def update_progress(current, total):
                progress = current / total
                progress_bar.progress(progress)
                status_text.text(f"Progress: {current}/{total} evaluations completed")

            results = process_evaluations_concurrently(
                selected_questions,
                prompt_template,
                models_to_evaluate,
                update_progress
            )

            # Group flat results per model for display and export.
            all_results = {}
            for result in results:
                model = result.pop('model_name')
                if model not in all_results:
                    all_results[model] = []
                all_results[model].append(result)

            st.session_state.all_results = all_results
            st.session_state.last_evaluated_dataset = selected_dataset

            if st.session_state.detailed_model is None and all_results:
                st.session_state.detailed_model = list(all_results.keys())[0]
            if st.session_state.detailed_dataset is None:
                st.session_state.detailed_dataset = selected_dataset

            st.success("Evaluation completed!")
            st.rerun()

    if st.session_state.all_results:
        st.subheader("Evaluation Results")

        model_metrics = {}
        for model_name, results in st.session_state.all_results.items():
            df = pd.DataFrame(results)
            metrics = {
                'Accuracy': df['is_correct'].mean(),
            }
            model_metrics[model_name] = metrics

        metrics_df = pd.DataFrame(model_metrics).T

        st.subheader("Model Performance Comparison")
        accuracy_chart = alt.Chart(
            metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
        ).mark_bar().encode(
            x=alt.X('index:N', title=None, axis=None),
            y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
            color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
            tooltip=['index:N', 'value:Q']
        ).properties(
            height=300,
            title={
                "text": "Model Accuracy",
                "baseline": "bottom",
                "orient": "bottom",
                "dy": 20
            }
        )
        st.altair_chart(accuracy_chart, use_container_width=True)

    if st.session_state.all_results:
        st.subheader("Detailed Results")

        def update_model():
            st.session_state.detailed_model = st.session_state.model_select

        def update_dataset():
            st.session_state.detailed_dataset = st.session_state.dataset_select

        col1, col2 = st.columns(2)
        with col1:
            selected_model_details = st.selectbox(
                "Select model",
                options=list(st.session_state.all_results.keys()),
                key="model_select",
                on_change=update_model,
                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
                if st.session_state.detailed_model in st.session_state.all_results else 0
            )
        with col2:
            selected_dataset_details = st.selectbox(
                "Select dataset",
                options=[st.session_state.last_evaluated_dataset],
                key="dataset_select",
                on_change=update_dataset
            )

        if selected_model_details in st.session_state.all_results:
            results = st.session_state.all_results[selected_model_details]
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")

            for idx, result in enumerate(results):
                with st.expander(f"Question {idx + 1} - {result['subject']}"):
                    st.write("**Question:**", result['question'])
                    st.write("**Options:**")
                    for i, opt in enumerate(result['options']):
                        st.write(f"{chr(65+i)}. {opt}")

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Prompt Used:**")
                        st.code(result['prompt_sent'])
                    with col2:
                        st.write("**Raw Response:**")
                        st.code(result['raw_llm_response'])

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Correct Answer:**", result['correct_answer'])
                        st.write("**Model Answer:**", result['model_response'])
                    with col2:
                        if result['is_correct']:
                            st.success("Correct!")
                        else:
                            st.error("Incorrect")
                    st.write("**Explanation:**", result['explanation'])
        else:
            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

        st.markdown("---")

        all_data = []
        for model_name, results in st.session_state.all_results.items():
            for question_idx, result in enumerate(results):
                row = {
                    'dataset': st.session_state.last_evaluated_dataset,
                    'model': model_name,
                    'question': result['question'],
                    'correct_answer': result['correct_answer'],
                    'subject': result['subject'],
                    'options': ' | '.join(result['options']),
                    'model_response': result['model_response'],
                    'is_correct': result['is_correct'],
                    'explanation': result['explanation']
                }
                all_data.append(row)

        complete_df = pd.DataFrame(all_data)
        csv = complete_df.to_csv(index=False)
        st.download_button(
            label="Download All Results as CSV",
            data=csv,
            file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
            mime="text/csv",
            key="download_all_results"
        )


if __name__ == "__main__":
    main()
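
# Environment sketch: the variable names below are exactly the ones read via
# os.getenv / os.environ above; the filename "app.py" is an assumption for illustration.
# A minimal .env for local runs might look like:
#
#   TOGETHERAI_API_KEY=...
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   GEMINI_API_KEY=...
#   STREAMLIT_USERNAME=...
#   STREAMLIT_PASSWORD=...
#
# Launch with:  streamlit run app.py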