Joschka Strueber commited on
Commit
36159b1
·
1 Parent(s): 465a95b

[Add, Fix] fix clearing model list, improve axes labels

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -13,6 +13,10 @@ def create_heatmap(selected_models, selected_dataset):
13
  if not selected_models or not selected_dataset:
14
  return None
15
 
 
 
 
 
16
  size = len(selected_models)
17
  similarities = np.random.rand(size, size)
18
  similarities = (similarities + similarities.T) / 2
@@ -27,14 +31,14 @@ def create_heatmap(selected_models, selected_dataset):
27
  cmap="viridis",
28
  vmin=0,
29
  vmax=1,
30
- xticklabels=selected_models,
31
- yticklabels=selected_models
32
  )
33
 
34
  # Customize plot
35
- plt.title(f"Similarity Matrix for {selected_dataset}", fontsize=14)
36
- plt.xlabel("Models")
37
- plt.ylabel("Models")
38
  plt.xticks(rotation=45, ha='right')
39
  plt.yticks(rotation=0)
40
  plt.tight_layout()
@@ -90,7 +94,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
90
 
91
  clear_btn = gr.Button("Clear Selection")
92
  clear_btn.click(
93
- lambda: [None, None, None],
94
  outputs=[model_dropdown, dataset_dropdown, heatmap]
95
  )
96
 
 
13
  if not selected_models or not selected_dataset:
14
  return None
15
 
16
+ selected_models = selected_models.sort()
17
+ selected_models_short = [model.split("/")[-1] for model in selected_models]
18
+
19
+ # Generate random similarity matrix
20
  size = len(selected_models)
21
  similarities = np.random.rand(size, size)
22
  similarities = (similarities + similarities.T) / 2
 
31
  cmap="viridis",
32
  vmin=0,
33
  vmax=1,
34
+ xticklabels=selected_models_short,
35
+ yticklabels=selected_models_short
36
  )
37
 
38
  # Customize plot
39
+ plt.title(f"Similarity Matrix for {selected_dataset}", fontsize=16)
40
+ plt.xlabel("Models", fontsize=14)
41
+ plt.ylabel("Models", fontsize=14)
42
  plt.xticks(rotation=45, ha='right')
43
  plt.yticks(rotation=0)
44
  plt.tight_layout()
 
94
 
95
  clear_btn = gr.Button("Clear Selection")
96
  clear_btn.click(
97
+ lambda: [[], None, None],
98
  outputs=[model_dropdown, dataset_dropdown, heatmap]
99
  )
100