ehagey committed
Commit 9c56e37 · verified · 1 Parent(s): 2020175

Update app.py

Files changed (1)
  1. app.py +224 -0
app.py CHANGED
@@ -0,0 +1,224 @@
import streamlit as st
import pandas as pd
from together import Together
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
import os
from config import DATASETS, MODELS

load_dotenv()
client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))

@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    # Keep only single-answer questions
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]

        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions

def get_model_response(question, options, prompt_template, model_name):
    try:
        model_config = MODELS[model_name]
        # Format options as "A. ...", "B. ...", etc. and fill the prompt template
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)

        response = client.chat.completions.create(
            model=model_config["model_id"],
            messages=[{"role": "user", "content": prompt}]
        )

        response_text = response.choices[0].message.content.strip()
        # Extract the first JSON object from the response and pull out the "answer" field
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: No JSON object found in model response"
        json_response = json.loads(json_match.group(0))
        answer = json_response['answer'].strip()
        # Strip a leading option letter ("A. ") if the model included one
        answer = re.sub(r'^[A-D]\.\s*', '', answer)

        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options"

        return answer

    except Exception as e:
        return f"Error: {str(e)}"

def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()

def main():
    st.set_page_config(page_title="Medical LLM Evaluation", layout="wide")
    st.title("Medical LLM Evaluation")

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_model = st.selectbox(
            "Select Model",
            options=list(MODELS.keys()),
            help="Choose the model to evaluate"
        )

    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "answer": the option you selected
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The prompt below is editable; feel free to adjust it before starting a run."
        )

    with col2:
        st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    if not questions:
        st.error("No questions were loaded successfully.")
        return

    subjects = list(set(q['subject_name'] for q in questions))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]

    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))

    if st.button("Start Evaluation"):
        if not os.getenv('TOGETHERAI_API_KEY'):
            st.error("Please set the TOGETHERAI_API_KEY in your .env file")
            return

        progress_bar = st.progress(0)
        status_text = st.empty()
        results_container = st.container()

        results = []
        for i in range(num_questions):
            question = questions[i]
            progress = (i + 1) / num_questions
            progress_bar.progress(progress)
            status_text.text(f"Evaluating question {i + 1}/{num_questions}")

            model_response = get_model_response(
                question['question'],
                question['options'],
                prompt_template,
                selected_model
            )

            # Re-send the same prompt to capture a raw completion for display.
            # Note: this is a second, independent API call, so the raw text shown
            # may not be the exact completion that produced model_response.
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
            formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
            raw_response = client.chat.completions.create(
                model=MODELS[selected_model]["model_id"],
                messages=[{"role": "user", "content": formatted_prompt}]
            ).choices[0].message.content.strip()

            is_correct = evaluate_response(model_response, question['correct_answer'])

            results.append({
                'question': question['question'],
                'options': question['options'],
                'model_response': model_response,
                'raw_llm_response': raw_response,
                'prompt_sent': formatted_prompt,
                'correct_answer': question['correct_answer'],
                'subject': question['subject_name'],
                'is_correct': is_correct,
                'explanation': question['explanation']
            })

        with results_container:
            st.subheader("Evaluation Results")
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")

            for idx, result in enumerate(results):
                st.markdown("---")
                st.subheader(f"Question {idx + 1} - {result['subject']}")

                st.write("Question:", result['question'])
                st.write("Options:")
                for i, opt in enumerate(result['options']):
                    st.write(f"{chr(65+i)}. {opt}")

                col1, col2 = st.columns(2)
                with col1:
                    with st.expander("Show Prompt"):
                        st.code(result['prompt_sent'])
                with col2:
                    with st.expander("Show Raw Response"):
                        st.code(result['raw_llm_response'])

                col1, col2 = st.columns(2)
                with col1:
                    st.write("Correct Answer:", result['correct_answer'])
                    st.write("Model Answer:", result['model_response'])
                with col2:
                    if result['is_correct']:
                        st.success("Correct!")
                    else:
                        st.error("Incorrect")

                with st.expander("Show Explanation"):
                    st.write(result['explanation'])

if __name__ == "__main__":
    main()
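
Note: app.py imports DATASETS and MODELS from a local config module that is not part of this commit, and it expects TOGETHERAI_API_KEY in a .env file loaded by python-dotenv. As a rough sketch of the shape the code appears to expect, the entries below are illustrative assumptions (dataset label, loader path, model label, and model ID are not taken from the repository's actual config):

# config.py -- illustrative sketch only; names, loader paths, and model IDs are assumptions
DATASETS = {
    "MedMCQA": {
        "loader": "openlifescienceai/medmcqa",  # Hugging Face dataset path (assumed)
    },
}

MODELS = {
    "Llama 3.1 8B Instruct Turbo": {
        "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # Together model ID (assumed)
    },
}

With a config of this shape in place and the API key set, the app would be launched with the standard Streamlit CLI, e.g. `streamlit run app.py`.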