Joschka Strueber committed on
Commit
93d753c
·
1 Parent(s): 1e010df

[Add, Fix] add loading mechanism for cached models, change error to warning when computing heatmap

Browse files
Files changed (2) hide show
  1. app.py +5 -3
  2. src/dataloading.py +20 -13
app.py CHANGED
@@ -29,7 +29,7 @@ def create_heatmap(selected_models, selected_dataset, selected_metric):
29
  failed_models.append(selected_models[i])
30
 
31
  if failed_models:
32
- raise gr.Error(f"Failed to load data for models: {', '.join(failed_models)}")
33
 
34
  # Create figure and heatmap using seaborn
35
  plt.figure(figsize=(8, 6))
@@ -94,6 +94,8 @@ links_markdown = """
94
  [🤗 Data](https://huggingface.co/datasets/bethgelab/lm-similarity)
95
  """
96
 
 
 
97
  # Create Gradio interface
98
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
99
  gr.Markdown("## Model Similarity Comparison Tool")
@@ -101,7 +103,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
101
 
102
  with gr.Row():
103
  dataset_dropdown = gr.Dropdown(
104
- choices=get_leaderboard_datasets(None),
105
  label="Select Dataset",
106
  value="mmlu_pro",
107
  filterable=True,
@@ -118,7 +120,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
118
  model_dropdown = gr.Dropdown(
119
  choices=get_leaderboard_models_cached(),
120
  label="Select Models",
121
- value=["HuggingFaceTB/SmolLM2-1.7B-Instruct", "tiiuae/Falcon3-7B-Instruct", "google/gemma-2-27b-it", "Qwen/Qwen2.5-72B-Instruct"],
122
  multiselect=True,
123
  filterable=True,
124
  allow_custom_value=False,
 
29
  failed_models.append(selected_models[i])
30
 
31
  if failed_models:
32
+ gr.Warning("Failed to load data for models: " + "\n".join(failed_models))
33
 
34
  # Create figure and heatmap using seaborn
35
  plt.figure(figsize=(8, 6))
 
94
  [🤗 Data](https://huggingface.co/datasets/bethgelab/lm-similarity)
95
  """
96
 
97
+ model_init = ["HuggingFaceTB/SmolLM2-1.7B-Instruct", "tiiuae/Falcon3-7B-Instruct", "google/gemma-2-27b-it", "Qwen/Qwen2.5-72B-Instruct"]
98
+
99
  # Create Gradio interface
100
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
101
  gr.Markdown("## Model Similarity Comparison Tool")
 
103
 
104
  with gr.Row():
105
  dataset_dropdown = gr.Dropdown(
106
+ choices=get_leaderboard_datasets(model_init),
107
  label="Select Dataset",
108
  value="mmlu_pro",
109
  filterable=True,
 
120
  model_dropdown = gr.Dropdown(
121
  choices=get_leaderboard_models_cached(),
122
  label="Select Models",
123
+ value=model_init,
124
  multiselect=True,
125
  filterable=True,
126
  allow_custom_value=False,
src/dataloading.py CHANGED
@@ -8,6 +8,9 @@ from datasets.exceptions import DatasetNotFoundError
8
 
9
  def get_leaderboard_models():
10
  api = HfApi()
 
 
 
11
 
12
  # List all datasets in the open-llm-leaderboard organization
13
  dataset_list = api.list_datasets(author="open-llm-leaderboard")
@@ -15,19 +18,23 @@ def get_leaderboard_models():
15
  models = []
16
  for dataset in dataset_list:
17
  if dataset.id.endswith("-details"):
18
- dataset_id = dataset.id
19
- try:
20
- # Check if the dataset can be loaded
21
- check_gated = datasets.get_dataset_config_names(dataset_id)
22
- # Format: "open-llm-leaderboard/<provider>__<model_name>-details"
23
- model_part = dataset_id.split("/")[-1].replace("-details", "")
24
- if "__" in model_part:
25
- provider, model = model_part.split("__", 1)
26
- models.append(f"{provider}/{model}")
27
- else:
28
- models.append(model_part)
29
- except Exception as e:
30
- pass
 
 
 
 
31
 
32
  # Save model list as txt file
33
  with open("models.txt", "w") as f:
 
8
 
9
  def get_leaderboard_models():
10
  api = HfApi()
11
+
12
+ # Load prechecked models
13
+ ungated_models = set(line.strip() for line in open("models.txt"))
14
 
15
  # List all datasets in the open-llm-leaderboard organization
16
  dataset_list = api.list_datasets(author="open-llm-leaderboard")
 
18
  models = []
19
  for dataset in dataset_list:
20
  if dataset.id.endswith("-details"):
21
+ # Format: "open-llm-leaderboard/<provider>__<model_name>-details"
22
+ model_part = dataset.id.split("/")[-1].replace("-details", "")
23
+ if "__" in model_part:
24
+ provider, model = model_part.split("__", 1)
25
+ model_name = f"{provider}/{model}"
26
+ else:
27
+ model_name = model_part
28
+
29
+ # Only perform the gating check if the model is not already in the ungated_models list.
30
+ if model_name not in ungated_models:
31
+ try:
32
+ # Check if the dataset can be loaded; if not, skip it.
33
+ datasets.get_dataset_config_names(dataset.id)
34
+ except Exception as e:
35
+ continue # Skip dataset if an exception occurs
36
+
37
+ models.append(model_name)
38
 
39
  # Save model list as txt file
40
  with open("models.txt", "w") as f: