floodsung / LearningToCompare_FSL

PyTorch code for CVPR 2018 paper: Learning to Compare: Relation Network for Few-Shot Learning (Few-Shot Learning part)
MIT License

Question about calculating accuracy #42

Open · datduonguva opened this issue 3 years ago

datduonguva commented 3 years ago

This is about how the accuracy on the test dataset is calculated here: https://github.com/floodsung/LearningToCompare_FSL/blob/master/omniglot/omniglot_train_one_shot.py#L237

sample_images,sample_labels = sample_dataloader.__iter__().next()
test_images,test_labels = test_dataloader.__iter__().next()

sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

sample_features_ext = sample_features.unsqueeze(0).repeat(SAMPLE_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
test_features_ext = test_features.unsqueeze(0).repeat(SAMPLE_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
test_features_ext = torch.transpose(test_features_ext,0,1)
relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,5,5)
relations = relation_network(relation_pairs).view(-1,CLASS_NUM)
_,predict_labels = torch.max(relations.data,1)
rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(CLASS_NUM)]
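The pairing step can be traced at the index level without a GPU. After the transpose and the torch.cat along dim 2, pair number a*n_sample + b holds test image a together with support image b. That means relations.view(-1, CLASS_NUM) has one row per test image and predict_labels has one entry per test image. (Since torch.cat requires every dimension except the concatenated one to match, the snippet as quoted also needs the test batch and the support set to both contain SAMPLE_NUM_PER_CLASS*CLASS_NUM images.) The sketch below is a pure-Python illustration, not code from the repository. It assumes one support image per class, with support image i belonging to class i. build_pair_indices and predict_from_scores are hypothetical helper names.

```python
def build_pair_indices(n_test, n_sample):
    """Hypothetical helper: list (test_idx, support_idx) in the order that
    transpose(test_ext, 0, 1) + cat(..., 2) + view(-1, ...) lays out the pairs.
    The test index varies slowest and the support index fastest."""
    return [(t, s) for t in range(n_test) for s in range(n_sample)]


def predict_from_scores(scores, class_num):
    """Hypothetical helper: split the flat relation scores into rows of
    class_num (like relations.view(-1, CLASS_NUM)) and take each row's argmax
    (like torch.max(relations, 1)). Assumes one support image per class,
    with support image i belonging to class i."""
    if len(scores) % class_num != 0:
        raise ValueError("number of scores must be a multiple of class_num")
    rows = [scores[i:i + class_num] for i in range(0, len(scores), class_num)]
    return [max(range(class_num), key=row.__getitem__) for row in rows]


if __name__ == "__main__":
    # 2 test images x 3 support images -> 6 pairs, grouped by test image.
    print(build_pair_indices(2, 3))
    # One row of scores per test image -> one predicted label per test image.
    print(predict_from_scores([0.1, 0.9, 0.0, 0.7, 0.2, 0.1], 3))  # -> [1, 0]
```

Because the number of predictions equals the number of test images, any loop that scores them has to cover len(test_labels) entries to see the whole batch.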

I think the rewards should be computed over all images in the test batch, so j in the last line should run over range(len(test_labels)).

Why does it sum over j in range(CLASS_NUM) instead?
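A minimal sketch of the proposed change, written in plain Python so it runs without a GPU. count_correct scores every test image, and count_correct_first_k mirrors the range(CLASS_NUM) loop. Both names are hypothetical, and the 10-image batch is made-up data. When the test batch has exactly CLASS_NUM images the two functions agree. With a larger batch, the second one ignores every prediction past index CLASS_NUM - 1.

```python
def count_correct(predict_labels, test_labels):
    """Number of correct predictions over the whole test batch
    (the j in range(len(test_labels)) version)."""
    if len(predict_labels) != len(test_labels):
        raise ValueError("need one prediction per test label")
    return sum(1 for p, t in zip(predict_labels, test_labels) if p == t)


def count_correct_first_k(predict_labels, test_labels, k):
    """Mirrors the quoted loop with j in range(CLASS_NUM): only the first k
    test images are scored."""
    return sum(1 for j in range(k) if predict_labels[j] == test_labels[j])


if __name__ == "__main__":
    CLASS_NUM = 5
    # Hypothetical batch of 10 test images (2 per class).
    preds = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    labels = [0, 1, 2, 3, 4, 4, 3, 2, 1, 0]
    full = count_correct(preds, labels)
    first = count_correct_first_k(preds, labels, CLASS_NUM)
    print(full, full / len(labels))  # 6 0.6
    print(first, first / CLASS_NUM)  # 5 1.0
```

With tensors, the same count could be written as (predict_labels == test_labels).sum().item(), assuming both tensors are on the same device. Dividing by len(test_labels) rather than CLASS_NUM then gives the accuracy for that batch.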