ehagey commited on
Commit
68ecdd8
·
verified ·
1 Parent(s): 221a11c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -13,6 +13,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
13
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
14
  import threading
15
  from anthropic import Anthropic
 
 
 
16
 
17
  load_dotenv()
18
 
@@ -29,6 +32,7 @@ anthropic_client = Anthropic(
29
  api_key=os.getenv('ANTHROPIC_API_KEY')
30
  )
31
 
 
32
 
33
  MAX_CONCURRENT_CALLS = 5
34
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
@@ -97,6 +101,16 @@ def get_model_response(question, options, prompt_template, model_name):
97
  )
98
  response_text = response.content[0].text
99
 
 
 
 
 
 
 
 
 
 
 
100
  json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
101
  if not json_match:
102
  return f"Error: Invalid response format", response_text
@@ -296,7 +310,7 @@ Important:
296
  ).mark_bar().encode(
297
  x=alt.X('index:N', title=None, axis=None),
298
  y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
299
- color='index:N',
300
  tooltip=['index:N', 'value:Q']
301
  ).properties(
302
  height=300,
@@ -307,6 +321,7 @@ Important:
307
  "dy": 20
308
  }
309
  )
 
310
  st.altair_chart(accuracy_chart, use_container_width=True)
311
 
312
  if st.session_state.all_results:
 
13
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
14
  import threading
15
  from anthropic import Anthropic
16
+ import google.generativeai as genai
17
+
18
+
19
 
20
  load_dotenv()
21
 
 
32
  api_key=os.getenv('ANTHROPIC_API_KEY')
33
  )
34
 
35
+ genai.configure(api_key=os.environ["GEMINI_API_KEY"])
36
 
37
  MAX_CONCURRENT_CALLS = 5
38
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
 
101
  )
102
  response_text = response.content[0].text
103
 
104
+ elif provider == "google":
105
+ model = genai.GenerativeModel(
106
+ model_name=model_config["model_id"]
107
+ )
108
+
109
+ chat_session = model.start_chat(
110
+ history=[]
111
+ )
112
+ response_text = chat_session.send_message(prompt).text
113
+
114
  json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
115
  if not json_match:
116
  return f"Error: Invalid response format", response_text
 
310
  ).mark_bar().encode(
311
  x=alt.X('index:N', title=None, axis=None),
312
  y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
313
+ color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
314
  tooltip=['index:N', 'value:Q']
315
  ).properties(
316
  height=300,
 
321
  "dy": 20
322
  }
323
  )
324
+
325
  st.altair_chart(accuracy_chart, use_container_width=True)
326
 
327
  if st.session_state.all_results: