ehagey commited on
Commit
d55da55
·
verified ·
1 Parent(s): 84fa20a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -11
app.py CHANGED
@@ -9,18 +9,27 @@ import os
9
  from config import DATASETS, MODELS
10
  import matplotlib.pyplot as plt
11
  import altair as alt
12
- import logging
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
14
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
15
  import threading
 
16
 
17
  load_dotenv()
18
 
19
- client = OpenAI(
20
  api_key=os.getenv('TOGETHERAI_API_KEY'),
21
  base_url="https://api.together.xyz/v1"
22
  )
23
 
 
 
 
 
 
 
 
 
 
24
  MAX_CONCURRENT_CALLS = 5
25
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
26
 
@@ -54,12 +63,6 @@ def load_dataset_by_name(dataset_name, split="train"):
54
  stop=stop_after_attempt(5),
55
  retry=retry_if_exception_type(Exception)
56
  )
57
- def fetch_model_response(prompt, model_id):
58
- response = client.chat.completions.create(
59
- model=model_id,
60
- messages=[{"role": "user", "content": prompt}]
61
- )
62
- return response
63
 
64
  def get_model_response(question, options, prompt_template, model_name):
65
  with semaphore:
@@ -67,9 +70,33 @@ def get_model_response(question, options, prompt_template, model_name):
67
  model_config = MODELS[model_name]
68
  options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
69
  prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
70
-
71
- response = fetch_model_response(prompt, model_config["model_id"])
72
- response_text = response.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
74
  if not json_match:
75
  return f"Error: Invalid response format", response_text
 
9
  from config import DATASETS, MODELS
10
  import matplotlib.pyplot as plt
11
  import altair as alt
 
12
  from concurrent.futures import ThreadPoolExecutor, as_completed
13
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
14
  import threading
15
+ from anthropic import Anthropic
16
 
17
  load_dotenv()
18
 
19
+ togetherai_client = OpenAI(
20
  api_key=os.getenv('TOGETHERAI_API_KEY'),
21
  base_url="https://api.together.xyz/v1"
22
  )
23
 
24
+ openai_client = OpenAI(
25
+ api_key=os.getenv('OPENAI_API_KEY')
26
+ )
27
+
28
+ anthropic_client = Anthropic(
29
+ api_key=os.getenv('ANTHROPIC_API_KEY')
30
+ )
31
+
32
+
33
  MAX_CONCURRENT_CALLS = 5
34
  semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)
35
 
 
63
  stop=stop_after_attempt(5),
64
  retry=retry_if_exception_type(Exception)
65
  )
 
 
 
 
 
 
66
 
67
  def get_model_response(question, options, prompt_template, model_name):
68
  with semaphore:
 
70
  model_config = MODELS[model_name]
71
  options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
72
  prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
73
+
74
+ provider = model_config["provider"]
75
+
76
+ if provider == "togetherai":
77
+ response = togetherai_client.chat.completions.create(
78
+ model=model_config["model_id"],
79
+ messages=[{"role": "user", "content": prompt}]
80
+ )
81
+ response_text = response.choices[0].message.content.strip()
82
+
83
+ elif provider == "openai":
84
+ response = openai_client.chat.completions.create(
85
+ model=model_config["model_id"],
86
+ messages=[{
87
+ "role": "user",
88
+ "content": prompt}]
89
+ )
90
+ response_text = response.choices[0].message.content.strip()
91
+
92
+ elif provider == "anthropic":
93
+ response = anthropic_client.messages.create(
94
+ model=model_config["model_id"],
95
+ messages=[{"role": "user", "content": prompt}],
96
+ max_tokens=4096
97
+ )
98
+ response_text = response.content[0].text
99
+
100
  json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
101
  if not json_match:
102
  return f"Error: Invalid response format", response_text