Update app.py
app.py CHANGED
@@ -0,0 +1,224 @@
import streamlit as st
import pandas as pd
from together import Together
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
import os
from config import DATASETS, MODELS

load_dotenv()
client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))

@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    # Load the configured dataset and keep only single-select questions.
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]

        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions

def get_model_response(question, options, prompt_template, model_name):
    # Send the formatted prompt to the selected model and extract the answer
    # from its JSON response. Returns an "Error: ..." string on failure.
    try:
        model_config = MODELS[model_name]
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)

        response = client.chat.completions.create(
            model=model_config["model_id"],
            messages=[{"role": "user", "content": prompt}]
        )

        response_text = response.choices[0].message.content.strip()
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: No JSON object found in the model response"
        json_response = json.loads(json_match.group(0))
        answer = json_response['answer'].strip()
        # Strip a leading option letter (e.g. "A. ") if the model included one.
        answer = re.sub(r'^[A-D]\.\s*', '', answer)

        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options"

        return answer

    except Exception as e:
        return f"Error: {str(e)}"

def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()

def main():
    st.set_page_config(page_title="Medical LLM Evaluation", layout="wide")
    st.title("Medical LLM Evaluation")

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_model = st.selectbox(
            "Select Model",
            options=list(MODELS.keys()),
            help="Choose the model to evaluate"
        )

    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "answer": the option you selected
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The prompt below is editable. Feel free to adjust it before running the evaluation."
        )

    with col2:
        st.markdown("""
        ### Prompt Variables
        - `{question}`: The medical question
        - `{options}`: The multiple choice options
        """)

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    if not questions:
        st.error("No questions were loaded successfully.")
        return

    subjects = list(set(q['subject_name'] for q in questions))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]

    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))

    if st.button("Start Evaluation"):
        if not os.getenv('TOGETHERAI_API_KEY'):
            st.error("Please set the TOGETHERAI_API_KEY in your .env file")
            return

        progress_bar = st.progress(0)
        status_text = st.empty()
        results_container = st.container()

        results = []
        for i in range(num_questions):
            question = questions[i]
            progress = (i + 1) / num_questions
            progress_bar.progress(progress)
            status_text.text(f"Evaluating question {i + 1}/{num_questions}")

            model_response = get_model_response(
                question['question'],
                question['options'],
                prompt_template,
                selected_model
            )

            # Call the model a second time with the same prompt so the raw,
            # unparsed response can be shown alongside the extracted answer.
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
            formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
            raw_response = client.chat.completions.create(
                model=MODELS[selected_model]["model_id"],
                messages=[{"role": "user", "content": formatted_prompt}]
            ).choices[0].message.content.strip()

            is_correct = evaluate_response(model_response, question['correct_answer'])

            results.append({
                'question': question['question'],
                'options': question['options'],
                'model_response': model_response,
                'raw_llm_response': raw_response,
                'prompt_sent': formatted_prompt,
                'correct_answer': question['correct_answer'],
                'subject': question['subject_name'],
                'is_correct': is_correct,
                'explanation': question['explanation']
            })

        with results_container:
            st.subheader("Evaluation Results")
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")

            for idx, result in enumerate(results):
                st.markdown("---")
                st.subheader(f"Question {idx + 1} - {result['subject']}")

                st.write("Question:", result['question'])
                st.write("Options:")
                for i, opt in enumerate(result['options']):
                    st.write(f"{chr(65+i)}. {opt}")

                col1, col2 = st.columns(2)
                with col1:
                    with st.expander("Show Prompt"):
                        st.code(result['prompt_sent'])
                with col2:
                    with st.expander("Show Raw Response"):
                        st.code(result['raw_llm_response'])

                col1, col2 = st.columns(2)
                with col1:
                    st.write("Correct Answer:", result['correct_answer'])
                    st.write("Model Answer:", result['model_response'])
                with col2:
                    if result['is_correct']:
                        st.success("Correct!")
                    else:
                        st.error("Incorrect")

                with st.expander("Show Explanation"):
                    st.write(result['explanation'])

if __name__ == "__main__":
    main()
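
Note: app.py imports DATASETS and MODELS from a local config module that is not part of this commit. The dataset columns used above (opa/opb/opc/opd, cop, choice_type, exp) match the MedMCQA schema, so a minimal config.py might look like the sketch below; the dataset hub id, the display names, and the Together model ids are illustrative assumptions, not values taken from this repository.

# config.py -- illustrative sketch only; the keys and ids used by this Space may differ.
DATASETS = {
    # "loader" is passed directly to datasets.load_dataset(); the hub id is an assumption.
    "MedMCQA": {"loader": "openlifescienceai/medmcqa"},
}

MODELS = {
    # "model_id" is passed to the Together chat completions API; these ids are examples.
    "Llama 3.1 8B Instruct Turbo": {"model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"},
    "Mixtral 8x7B Instruct": {"model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1"},
}

With a TOGETHERAI_API_KEY entry in a .env file next to app.py, the app can then be started with `streamlit run app.py`.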