Spaces:
Runtime error
Runtime error
import torch | |
def mic_acc_cal(preds, labels):
    """Compute micro-averaged top-1 accuracy for one batch of predictions.

    Args:
        preds: Tensor of predicted class indices, compared element-wise
            against the target indices.
        labels: Either a tensor of target class indices (same length as
            ``preds``), or a 3-tuple ``(targets_a, targets_b, lam)`` in the
            style of mixup augmentation. In the tuple case the score is
            ``lam`` times the hit count against ``targets_a`` plus
            ``1 - lam`` times the hit count against ``targets_b``, divided
            by the batch length.

    Returns:
        The accuracy as a Python ``float`` in both cases.

    Raises:
        ValueError: If ``labels`` is a tuple whose length is not 3.

    Note:
        Both branches divide by the batch length, so an empty batch is not
        supported.
    """
    if isinstance(labels, tuple):
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if len(labels) != 3:
            raise ValueError(
                "mixup labels must be a (targets_a, targets_b, lam) tuple, "
                f"got a tuple of length {len(labels)}"
            )
        targets_a, targets_b, lam = labels
        hits_a = preds.eq(targets_a.data).cpu().sum().float()
        hits_b = preds.eq(targets_b.data).cpu().sum().float()
        acc_mic_top1 = (lam * hits_a + (1 - lam) * hits_b) / len(preds)
        # Convert the 0-dim tensor to a Python float so both branches
        # return the same type.
        return float(acc_mic_top1)
    return (preds == labels).sum().item() / len(labels)