Joschka Strueber commited on
Commit
75132dc
·
1 Parent(s): b1f98e1

[Fix] error in deleting not-matching gt values

Browse files
Files changed (1) hide show
  1. src/similarity.py +7 -10
src/similarity.py CHANGED
@@ -35,7 +35,7 @@ def compute_pairwise_similarities(metric_name: str, probs: list[list[np.array]],
35
  metric = CAPA()
36
  elif metric_name == "CAPA (det.)":
37
  metric = CAPA(prob=False)
38
- # Convert probabilities to one-hot
39
  probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
40
  elif metric_name == "Error Consistency":
41
  probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
@@ -48,25 +48,22 @@ def compute_pairwise_similarities(metric_name: str, probs: list[list[np.array]],
48
  for j in range(i, len(probs)):
49
  outputs_a = probs[i]
50
  outputs_b = probs[j]
51
- gt_a = gts[i]
52
- gt_b = gts[j]
53
 
54
  # Format softmax outputs
55
  if metric_name == "CAPA":
56
  outputs_a = [softmax(logits) for logits in outputs_a]
57
  outputs_b = [softmax(logits) for logits in outputs_b]
58
 
59
- # Assert that the ground truth index is the same
60
- indices_to_remove = []
61
- if gt_a != gt_b:
62
- for idx, (a, b) in enumerate(zip(gt_a, gt_b)):
63
- if a != b:
64
- indices_to_remove.append(idx)
65
  for idx in sorted(indices_to_remove, reverse=True):
66
  del outputs_a[idx]
67
  del outputs_b[idx]
68
  del gt_a[idx]
69
- del gt_b[idx]
70
 
71
  try:
72
  similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
 
35
  metric = CAPA()
36
  elif metric_name == "CAPA (det.)":
37
  metric = CAPA(prob=False)
38
+ # Convert logits to one-hot
39
  probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
40
  elif metric_name == "Error Consistency":
41
  probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
 
48
  for j in range(i, len(probs)):
49
  outputs_a = probs[i]
50
  outputs_b = probs[j]
51
+ gt_a = gts[i].copy()
52
+ gt_b = gts[j].copy()
53
 
54
  # Format softmax outputs
55
  if metric_name == "CAPA":
56
  outputs_a = [softmax(logits) for logits in outputs_a]
57
  outputs_b = [softmax(logits) for logits in outputs_b]
58
 
59
+ # Remove indices where the ground truth differs
60
+ # (This code assumes gt_a and gt_b are lists of integers.)
61
+ indices_to_remove = [idx for idx, (a, b) in enumerate(zip(gt_a, gt_b)) if a != b]
 
 
 
62
  for idx in sorted(indices_to_remove, reverse=True):
63
  del outputs_a[idx]
64
  del outputs_b[idx]
65
  del gt_a[idx]
66
+ del gt_b[idx]
67
 
68
  try:
69
  similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)