floodsung / LearningToCompare_FSL

PyTorch code for CVPR 2018 paper: Learning to Compare: Relation Network for Few-Shot Learning (Few-Shot Learning part)
MIT License

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index' #30

Closed: haofengsiji closed this issue 4 years ago

haofengsiji commented 5 years ago

I have tried my best to solve this problem but failed. Can anyone give me some suggestions?

(py36) E:\GitHub\LearningToCompare_FSL>F:/Anaconda3/envs/py36/python.exe e:/GitHub/LearningToCompare_FSL/omniglot/omniglot_train_few_shot.py
init data folders
init neural networks
Training...
F:\Anaconda3\envs\py36\lib\site-packages\torch\nn\functional.py:1386: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
  File "e:/GitHub/LearningToCompare_FSL/omniglot/omniglot_train_few_shot.py", line 264, in <module>
    main()
  File "e:/GitHub/LearningToCompare_FSL/omniglot/omniglot_train_few_shot.py", line 188, in main
    one_hot_labels = Variable(torch.zeros(BATCH_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1,1), 1).cuda(GPU))
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'
yitianhoulai commented 4 years ago

The third argument, 'index', is batch_labels.view(-1,1). It is an Int (int32) tensor, but scatter_ requires a Long (int64) index, so you just need to cast it with .long():

one_hot_labels = Variable(torch.zeros(BATCH_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1,1).long(), 1).cuda(GPU))
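A minimal, CPU-only sketch of the same fix, assuming the labels reach this line as torch.int32. One plausible cause, which the thread does not confirm, is that on Windows NumPy's default integer type is 32-bit, so labels built from NumPy arrays become Int tensors. The sizes below are hypothetical stand-ins for the script's constants. The .cuda(GPU) call is omitted so the snippet runs without a GPU, and Variable is dropped because plain tensors have replaced it since PyTorch 0.4. The second helper shows torch.nn.functional.one_hot (PyTorch 1.1+) as an alternative. It also requires int64 input.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes standing in for CLASS_NUM / BATCH_NUM_PER_CLASS
# from omniglot_train_few_shot.py (the real values are set in the script).
CLASS_NUM = 5
BATCH_NUM_PER_CLASS = 2

# Simulate labels that arrive as 32-bit ints ("Int"), the dtype that
# triggers "Expected object of scalar type Long but got scalar type Int".
batch_labels = torch.arange(CLASS_NUM, dtype=torch.int32).repeat(BATCH_NUM_PER_CLASS)


def one_hot_scatter(labels, num_classes):
    """One-hot encode with scatter_, casting the index to int64 first."""
    index = labels.view(-1, 1).long()  # scatter_ needs a Long (int64) index
    return torch.zeros(labels.numel(), num_classes).scatter_(1, index, 1)


def one_hot_builtin(labels, num_classes):
    """Same encoding via F.one_hot, which also expects int64 labels."""
    return F.one_hot(labels.long(), num_classes).float()


a = one_hot_scatter(batch_labels, CLASS_NUM)
b = one_hot_builtin(batch_labels, CLASS_NUM)
print(a.shape)            # torch.Size([10, 5])
print(torch.equal(a, b))  # True
```

In the training script itself, the equivalent change is the one-line .long() cast above. Casting the labels once, where the batch is loaded, would also fix any other indexing op that receives them.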

haofengsiji commented 4 years ago

Switching to Python 2.7 (py27) may also avoid this problem, but thanks anyway.