orobix / Prototypical-Networks-for-Few-shot-Learning-PyTorch

Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in Pytorch
MIT License
986 stars 210 forks

error at accuracy #23

Open hahmyg opened 4 years ago

hahmyg commented 4 years ago

Dear author

Thank you for your carefully written code. I reused some of it and found an error.

Please check line 84 in prototypical_loss.py. I think y_hat should be squeezed with squeeze().

y_hat and target_inds.squeeze() look like:

y_hat = torch.tensor([[0],[1],[2],[0],[4]])
target_inds.squeeze() = torch.tensor([0, 1, 2, 3, 4])

In this case,

y_hat.eq(target_inds.squeeze()).float()

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])

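Because y_hat has shape (5, 1) and target_inds.squeeze() has shape (5,), eq() broadcasts them into a 5x5 matrix. Taking the mean of that matrix gives an accuracy of 0.2.

The comparison should instead give tensor([1., 1., 1., 0., 1.]), in which case the accuracy is 0.8.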
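
A minimal sketch that reproduces both results with the values above. The variable names `target`, `broadcast_acc`, `correct` and `fixed_acc` are illustrative and not taken from prototypical_loss.py. It uses `squeeze(1)` rather than a bare `squeeze()`, which is my assumption about the safer choice, since a bare `squeeze()` would also drop the batch dimension when there is only one sample.

```python
import torch

# Values copied from the example in this issue.
y_hat = torch.tensor([[0], [1], [2], [0], [4]])  # shape (5, 1)
target = torch.tensor([0, 1, 2, 3, 4])           # shape (5,)

# Comparing (5, 1) with (5,) broadcasts to a (5, 5) matrix,
# so the mean is taken over 25 entries instead of 5.
broadcast_matrix = y_hat.eq(target).float()
broadcast_acc = broadcast_matrix.mean().item()

# Dropping the trailing dimension of y_hat makes the comparison
# element-wise, giving one 0/1 entry per sample.
correct = y_hat.squeeze(1).eq(target).float()
fixed_acc = correct.mean().item()

print(broadcast_matrix.shape)  # shape of the unintended comparison
print(broadcast_acc)           # approximately 0.2
print(correct)
print(fixed_acc)               # approximately 0.8
```

Checking that the compared tensors have the same shape before calling eq() would catch this kind of silent broadcasting.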