On line 291:
per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float())
the computed accuracy can be NaN. For example, if class i has no samples in the test set, is_class.sum() is 0 and the division evaluates to 0/0, which is NaN.
Therefore, I think the code should be changed as follows:
if torch.any(torch.isnan(per_class_accuracies[i])):
    per_class_accuracies[i] = torch.where(torch.isnan(per_class_accuracies[i]), torch.full_like(per_class_accuracies[i], 0), per_class_accuracies[i])
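For context, here is a minimal self-contained sketch of the per-class loop with the proposed NaN handling applied. It assumes integer class labels 0..num_classes-1 and a loop over classes around line 291; the function name per_class_accuracy and the num_classes parameter are hypothetical, not taken from the original code. It uses torch.nan_to_num (available in PyTorch 1.8+) as a shorter equivalent of the torch.where replacement above.

```python
import torch

def per_class_accuracy(predicted_label: torch.Tensor,
                       test_label: torch.Tensor,
                       num_classes: int) -> torch.Tensor:
    """Hypothetical helper: returns a (num_classes,) tensor of accuracies.

    Classes with no test samples would give 0/0 = NaN; they are set to 0.0.
    """
    per_class_accuracies = torch.zeros(num_classes)
    for i in range(num_classes):
        is_class = test_label == i
        correct = (predicted_label[is_class] == test_label[is_class]).sum().float()
        total = is_class.sum().float()
        # When total == 0 this division is 0/0 and yields NaN;
        # nan_to_num replaces that NaN with 0.0, as proposed above.
        per_class_accuracies[i] = torch.nan_to_num(torch.div(correct, total), nan=0.0)
    return per_class_accuracies

# Class 2 never appears in the labels, so its accuracy would otherwise be NaN.
acc = per_class_accuracy(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0]), 3)
print(acc)
```

Note that reporting 0 for an absent class lowers the mean over classes; if that is undesirable, an alternative is to keep the NaN and average with torch.nanmean (PyTorch 1.10+), which skips classes that have no test samples.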