edgarschnfld / CADA-VAE-PyTorch

Official implementation of the paper "Generalized Zero- and Few-Shot Learning via Aligned Variational Autoencoders" (CVPR 2019)
MIT License
283 stars, 57 forks

There is a bug in final_classifier.py #32

(Open) Bingyang0410 opened this issue 1 year ago

Bingyang0410 commented 1 year ago

Line 291 computes the per-class accuracy as:

    per_class_accuracies[i] = torch.div((predicted_label[is_class] == test_label[is_class]).sum().float(), is_class.sum().float())

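This can produce NaN. If a class has no samples in test_label, both the numerator and is_class.sum() are 0, so the division is 0/0 and per_class_accuracies[i] becomes NaN. Any average taken over per_class_accuracies then becomes NaN as well.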
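To see the failure in isolation, here is a minimal reproduction on made-up toy labels, where class 2 is listed in target_classes but never appears in test_label. Only the expression inside the loop is taken from line 291; the tensor values and the CPU-only setup are assumptions for illustration.

```python
import torch

# Hypothetical toy data: class 2 is a target class but has no test samples.
test_label = torch.tensor([0, 0, 1, 1])
predicted_label = torch.tensor([0, 1, 1, 1])
target_classes = torch.tensor([0, 1, 2])

per_class_accuracies = torch.zeros(target_classes.size(0))
for i in range(target_classes.size(0)):
    is_class = test_label == target_classes[i]
    # Same expression as line 291: correct predictions / number of samples of this class.
    per_class_accuracies[i] = torch.div(
        (predicted_label[is_class] == test_label[is_class]).sum().float(),
        is_class.sum().float(),
    )

print(per_class_accuracies)         # class 0 -> 0.5, class 1 -> 1.0, class 2 -> nan (0/0)
print(per_class_accuracies.mean())  # nan, because the NaN propagates into the mean
```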

Thus, I think the code needs to be changed as follows:

    # A class with no test samples makes line 291 compute 0/0 = NaN; replace it with 0.
    if torch.any(torch.isnan(per_class_accuracies[i])):
        per_class_accuracies[i] = torch.where(torch.isnan(per_class_accuracies[i]),
                                              torch.full_like(per_class_accuracies[i], 0),
                                              per_class_accuracies[i])
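For reference, here is a minimal sketch of the whole per-class accuracy loop with the fix applied. The function name per_class_accuracy and the skip_empty flag are hypothetical and not part of CADA-VAE; the sketch also assumes PyTorch 1.8 or newer for torch.nan_to_num (on older versions the torch.where form proposed above does the same job). Whether an empty class should count as 0 accuracy or be left out of the average is a design choice: counting it as 0 lowers the reported mean, while skipping it averages only over classes that actually occur in the test split.

```python
import torch


def per_class_accuracy(test_label, predicted_label, target_classes, skip_empty=False):
    """Hypothetical helper mirroring the loop around line 291 of final_classifier.py.

    A class with no samples in test_label would give 0/0 = NaN. By default it is
    counted as 0 accuracy (as proposed in this issue); with skip_empty=True it is
    dropped from the mean instead.
    """
    n = target_classes.size(0)
    accs = torch.zeros(n, device=test_label.device)
    counts = torch.zeros(n, device=test_label.device)
    for i in range(n):
        is_class = test_label == target_classes[i]
        counts[i] = is_class.sum().float()
        correct = (predicted_label[is_class] == test_label[is_class]).sum().float()
        accs[i] = correct / counts[i]  # NaN when counts[i] == 0
    if skip_empty:
        # Average only over classes that actually appear in test_label.
        return accs[counts > 0].mean()
    # Replace NaN with 0.0, equivalent to the torch.where replacement proposed above.
    return torch.nan_to_num(accs, nan=0.0).mean()
```

For example, with test_label = [0, 0, 1, 1], predicted_label = [0, 1, 1, 1] and target_classes = [0, 1, 2], the default call returns 0.5 (class 2 counted as 0), while skip_empty=True returns 0.75.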