tengshaofeng / ResidualAttentionNetwork-pytorch

A PyTorch implementation of Residual Attention Network. This code is based on two projects from

During testing on CIFAR-10, the output data structure is incorrect #31

Open xk88372527 opened 4 years ago

xk88372527 commented 4 years ago

Test code:

```python
print('%s :Accuracy of the model on the test images: %d %%' % (datetime.now(),100 * correct / total))

# print('Accuracy of the model on the test images:', correct.item()/total)
# print(correct.item())
# print(total)
# for i in range(10):
#     print('%s :Accuracy of %5s : %2d %%' % (
#         datetime.now(),classes[i],  class_correct[i].item() / class_total[i]))
#     print(class_correct[i].item())
#     print(class_total[i])
# return correct / total
```
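A separate issue with the commented-out per-class print: `class_correct[i].item() / class_total[i]` is a fraction between 0 and 1, and `%2d` truncates it to an integer, so it prints 0 regardless of the counts. The overall line avoids this by multiplying by 100 first. A minimal pure-Python sketch of the difference, using the "plane" values printed in the log (194 and 1000.0):

```python
# Reproduces the %-formatting behaviour of the commented-out per-class print.
# 194 and 1000.0 are the "plane" values printed in the log.
correct_count = 194
total_count = 1000.0

ratio = correct_count / total_count  # 0.194: a fraction, not a percentage

# '%2d' truncates the float to an int, so any ratio below 1 prints as 0.
as_printed = '%2d %%' % ratio
# Scaling by 100 first, as the overall-accuracy line does, yields a percentage.
as_percent = '%2d %%' % (100 * ratio)

print(repr(as_printed))
print(repr(as_percent))
```

So even with correct counters, the per-class line needs `100 *` in front of the ratio.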

Output:

```
D:\Microsoft Visual Studio\Shared\Anaconda3_64\envs\xk\lib\site-packages\torch\nn\modules\upsampling.py:129: UserWarning: nn.UpsamplingBilinear2d is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
2020-03-31 15:32:25.001979 :Accuracy of the model on the test images: 95 %
Accuracy of the model on the test images: 0.954
9540
10000
2020-03-31 15:32:25.002979 :Accuracy of plane : 0 %
194
1000.0
2020-03-31 15:32:25.002979 :Accuracy of car : 0 %
206
1000.0
2020-03-31 15:32:25.002979 :Accuracy of bird : 0 %
169
1000.0
2020-03-31 15:32:25.002979 :Accuracy of cat : 0 %
136
1000.0
2020-03-31 15:32:25.002979 :Accuracy of deer : 0 %
187
1000.0
2020-03-31 15:32:25.002979 :Accuracy of dog : 0 %
159
1000.0
2020-03-31 15:32:25.003980 :Accuracy of frog : 0 %
204
1000.0
2020-03-31 15:32:25.003980 :Accuracy of horse : 0 %
197
1000.0
2020-03-31 15:32:25.003980 :Accuracy of ship : 0 %
205
1000.0
2020-03-31 15:32:25.003980 :Accuracy of truck : 0 %
203
1000.0
```

Without `.item()`, the output becomes:

```
D:\Microsoft Visual Studio\Shared\Anaconda3_64\envs\xk\lib\site-packages\torch\nn\modules\upsampling.py:129: UserWarning: nn.UpsamplingBilinear2d is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
2020-03-31 15:38:02.784257 :Accuracy of the model on the test images: 95 %
Accuracy of the model on the test images: tensor(0, device='cuda:0')
9540
10000
2020-03-31 15:38:02.785258 :Accuracy of plane : 0 %
tensor(194, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.786259 :Accuracy of car : 0 %
tensor(206, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.786259 :Accuracy of bird : 0 %
tensor(169, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.787259 :Accuracy of cat : 0 %
tensor(136, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.787259 :Accuracy of deer : 0 %
tensor(187, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.788261 :Accuracy of dog : 0 %
tensor(159, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.789261 :Accuracy of frog : 0 %
tensor(204, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.789261 :Accuracy of horse : 0 %
tensor(197, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.789261 :Accuracy of ship : 0 %
tensor(205, device='cuda:0', dtype=torch.uint8)
1000.0
2020-03-31 15:38:02.790264 :Accuracy of truck : 0 %
tensor(203, device='cuda:0', dtype=torch.uint8)
1000.0
```
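The log without `.item()` shows that the per-class counters are `torch.uint8` tensors. An 8-bit unsigned integer only holds 0 to 255 and wraps around modulo 256, so per-class counts near 1000 cannot be stored. The printed values are consistent with that: adding 3 × 256 = 768 to each per-class count gives totals that sum exactly to the 9540 correct predictions reported overall. A pure-Python check of that arithmetic; the "wrapped exactly three times" step is an assumption, not something the log states directly:

```python
# Per-class counts exactly as printed in the logs (plane ... truck).
printed = {
    'plane': 194, 'car': 206, 'bird': 169, 'cat': 136, 'deer': 187,
    'dog': 159, 'frog': 204, 'horse': 197, 'ship': 205, 'truck': 203,
}

UINT8_RANGE = 256  # a uint8 counter stores its value modulo 256


def wrap_uint8(n):
    """Value an 8-bit unsigned counter shows after counting up to n."""
    return n % UINT8_RANGE


# Assumption: every class had between 768 and 1000 correct predictions,
# i.e. each counter wrapped around exactly three times.
recovered = {name: count + 3 * UINT8_RANGE for name, count in printed.items()}
total_recovered = sum(recovered.values())

print(recovered)
print(total_recovered)  # compare with the overall count of 9540
```

Every recovered count stays at or below the 1000 test images per class, and their sum matches the overall total, which points to 8-bit overflow in the per-class counters rather than a problem with the model.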

I hope you can help. Thanks!

Ixuanzhang commented 3 years ago

In the `test` function in train.py, change `c = (predicted == labels.data).squeeze()` to `c = (predicted == labels.data).squeeze().tolist()` and it works.
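Why this helps, sketched under assumptions: `.tolist()` turns the comparison tensor into a plain Python list, so the per-class counters accumulate Python numbers instead of `uint8` tensors and can no longer wrap at 256. The function below is a hypothetical pure-Python stand-in for the per-class loop of `test()` (the name `accumulate_per_class` and the list inputs are illustrative, not the repository's code), and the 962-of-1000 batch is made-up data:

```python
def accumulate_per_class(predicted, labels, class_correct, class_total):
    """Hypothetical stand-in for the per-class loop in test().

    predicted and labels are plain Python lists, as .tolist() would produce.
    """
    # Plays the role of (predicted == labels.data).squeeze().tolist():
    # one Python bool per sample.
    c = [p == t for p, t in zip(predicted, labels)]
    for i, label in enumerate(labels):
        class_correct[label] += c[i]  # bool adds as 0 or 1; Python floats do not wrap
        class_total[label] += 1


num_classes = 10
class_correct = [0.0] * num_classes
class_total = [0.0] * num_classes

# Illustrative data: 1000 images of class 0, 962 predicted correctly,
# more than a uint8 counter (max 255) could hold.
labels = [0] * 1000
predicted = [0] * 962 + [1] * 38
accumulate_per_class(predicted, labels, class_correct, class_total)

line = 'Accuracy of %5s : %2d %%' % ('plane', 100 * class_correct[0] / class_total[0])
print(line)
```

Note the `100 *` in the format arguments: without it, `%2d` would still print 0 even with correct counts.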