ShusenTang / Dive-into-DL-PyTorch

This project ports the MXNet implementations in the original book 《动手学深度学习》 (Dive into Deep Learning) to PyTorch.
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0

3.6.6 Evaluating the accuracy of net on data_iter #163

Open CKing111 opened 3 years ago

CKing111 commented 3 years ago

I copied the code as-is and got the error below. I don't know how to fix it; any help would be appreciated.

```python
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

print(evaluate_accuracy(test_iter, net))
```
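`evaluate_accuracy` itself only assumes that `net(X)` returns a `(batch_size, num_classes)` score tensor and that `y` holds integer class labels of shape `(batch_size,)`. The minimal sketch below runs it on two hand-made batches; `toy_iter` and `identity_net` are hypothetical stand-ins for `test_iter` and `net`, not part of the book.

```python
import torch

def evaluate_accuracy(data_iter, net):
    # Count samples whose highest-scoring class equals the label, then
    # divide by the total number of samples seen across all batches.
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

# Hypothetical stand-in for test_iter: a list of (X, y) batches.
toy_iter = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0])),  # 1 of 2 correct
    (torch.tensor([[0.3, 0.7]]), torch.tensor([1])),                 # 1 of 1 correct
]

def identity_net(X):
    # Hypothetical stand-in for net: the inputs are used directly as class scores.
    return X

acc = evaluate_accuracy(toy_iter, identity_net)
print(acc)  # 0.6666666666666666
```

Because `n` accumulates `y.shape[0]` per batch, batches of different sizes (such as a smaller last batch) are weighted correctly.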


```
RuntimeError                              Traceback (most recent call last)
in <module>
      6         n += y.shape[0]
      7     return acc_sum / n
----> 8 print(evaluate_accuracy(test_iter, net))

in evaluate_accuracy(data_iter, net)
      3     acc_sum, n = 0.0, 0
      4     for X, y in data_iter:
----> 5         acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
      6         n += y.shape[0]
      7     return acc_sum / n

in net(X)
      1 def net(X):
----> 2     return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1
```

**Version info**
pytorch:
torchvision:
torchtext:
...
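Reading the last frame, the failure comes from `torch.mm(X.view((-1, num_inputs)), W) + b`: the matrix product has 10 columns (one per class), so broadcasting needs `b` to have 10 entries, while the message reports size 3 for the second operand. The sketch below assumes Fashion-MNIST shapes (`num_inputs = 784`, `num_outputs = 10`) and uses a hypothetical helper `linear_softmax` that takes `W` and `b` as arguments. It reproduces the same kind of error with a 3-element bias and shows that a bias of shape `(num_outputs,)` works.

```python
import torch

# Assumed Fashion-MNIST sizes: 28*28 = 784 input pixels, 10 classes.
num_inputs, num_outputs = 784, 10

def softmax(X):
    # Row-wise softmax, written like the one in this section.
    X_exp = X.exp()
    return X_exp / X_exp.sum(dim=1, keepdim=True)

def linear_softmax(X, W, b):
    # Hypothetical helper: the same expression as net(X) in the traceback,
    # but with W and b passed in so different shapes can be tried.
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

torch.manual_seed(0)
X = torch.rand(4, 1, 28, 28)                      # a fake batch of 4 images
W = torch.randn(num_inputs, num_outputs) * 0.01   # shape (784, 10)

b = torch.zeros(num_outputs)                      # shape (10,): broadcasts over the batch
probs = linear_softmax(X, W, b)
print(probs.shape)                                # torch.Size([4, 10])

b_wrong = torch.zeros(3)                          # shape (3,): cannot broadcast with (4, 10)
error_message = None
try:
    linear_softmax(X, W, b_wrong)
except RuntimeError as e:
    error_message = str(e)
print(error_message)                              # a size-mismatch message like the one in the issue
```

Printing `W.shape` and `b.shape` right before calling `evaluate_accuracy` should show which parameter no longer has the expected size.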