Spaces:
Running
Running
Joschka Strueber
commited on
Commit
·
75132dc
1
Parent(s):
b1f98e1
[Fix] error in deleting not-matching gt values
Browse files- src/similarity.py +7 -10
src/similarity.py
CHANGED
@@ -35,7 +35,7 @@ def compute_pairwise_similarities(metric_name: str, probs: list[list[np.array]],
|
|
35 |
metric = CAPA()
|
36 |
elif metric_name == "CAPA (det.)":
|
37 |
metric = CAPA(prob=False)
|
38 |
-
# Convert
|
39 |
probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
|
40 |
elif metric_name == "Error Consistency":
|
41 |
probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
|
@@ -48,25 +48,22 @@ def compute_pairwise_similarities(metric_name: str, probs: list[list[np.array]],
|
|
48 |
for j in range(i, len(probs)):
|
49 |
outputs_a = probs[i]
|
50 |
outputs_b = probs[j]
|
51 |
-
gt_a = gts[i]
|
52 |
-
gt_b = gts[j]
|
53 |
|
54 |
# Format softmax outputs
|
55 |
if metric_name == "CAPA":
|
56 |
outputs_a = [softmax(logits) for logits in outputs_a]
|
57 |
outputs_b = [softmax(logits) for logits in outputs_b]
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
for idx, (a, b) in enumerate(zip(gt_a, gt_b)):
|
63 |
-
if a != b:
|
64 |
-
indices_to_remove.append(idx)
|
65 |
for idx in sorted(indices_to_remove, reverse=True):
|
66 |
del outputs_a[idx]
|
67 |
del outputs_b[idx]
|
68 |
del gt_a[idx]
|
69 |
-
del gt_b[idx]
|
70 |
|
71 |
try:
|
72 |
similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
|
|
|
35 |
metric = CAPA()
|
36 |
elif metric_name == "CAPA (det.)":
|
37 |
metric = CAPA(prob=False)
|
38 |
+
# Convert logits to one-hot
|
39 |
probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
|
40 |
elif metric_name == "Error Consistency":
|
41 |
probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
|
|
|
48 |
for j in range(i, len(probs)):
|
49 |
outputs_a = probs[i]
|
50 |
outputs_b = probs[j]
|
51 |
+
gt_a = gts[i].copy()
|
52 |
+
gt_b = gts[j].copy()
|
53 |
|
54 |
# Format softmax outputs
|
55 |
if metric_name == "CAPA":
|
56 |
outputs_a = [softmax(logits) for logits in outputs_a]
|
57 |
outputs_b = [softmax(logits) for logits in outputs_b]
|
58 |
|
59 |
+
# Remove indices where the ground truth differs
|
60 |
+
# (This code assumes gt_a and gt_b are lists of integers.)
|
61 |
+
indices_to_remove = [idx for idx, (a, b) in enumerate(zip(gt_a, gt_b)) if a != b]
|
|
|
|
|
|
|
62 |
for idx in sorted(indices_to_remove, reverse=True):
|
63 |
del outputs_a[idx]
|
64 |
del outputs_b[idx]
|
65 |
del gt_a[idx]
|
66 |
+
del gt_b[idx]
|
67 |
|
68 |
try:
|
69 |
similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
|