ShusenTang / Dive-into-DL-PyTorch

This project ports the MXNet implementations in the original book Dive into Deep Learning (《动手学深度学习》) to PyTorch.
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0
18.07k stars 5.37k forks

Type problem with W1 and W2 in Section 3.9.4 #156

Closed: Liang-Liao closed this issue 3 years ago

Liang-Liao commented 3 years ago

Bug description: computing with the net function from Section 3.9.4 raises an error in my environment.

def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2

loss = torch.nn.CrossEntropyLoss()

num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

The error is as follows:

RuntimeError  Traceback (most recent call last)
<ipython-input-52-c1201a53ebe9> in <module>
      1 num_epochs, lr = 5, 100.0
----> 2 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

~/liang/d2lzh_pytorch.py in train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr, optimizer)
     84         train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
     85         for X, y in train_iter:
---> 86             y_hat = net(X)
     87             l = loss(y_hat, y).sum()
     88 

<ipython-input-50-c182b51c4bb0> in net(X)
      1 def net(X):
      2     X = X.view((-1, num_inputs))
----> 3     H = relu(torch.matmul(X, W1) + b1)
      4     return torch.matmul(H, W2) + b2

RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat2' in call to _th_mm

Version info: pytorch: 1.4.0, torchvision: 0.5.0, torchtext: ...
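
The traceback reports that `torch.matmul` received a Float (float32) left operand and a Double (float64) `mat2`. The thread does not show how `W1` and `W2` were created, so the following is only a minimal sketch of one plausible cause: `np.random.normal` returns float64 arrays, and `torch.tensor` keeps that dtype unless one is passed explicitly. The sizes and the helper `dtype_mismatch_raises` are assumptions made for illustration, and the exact wording of the error message differs between PyTorch versions.

```python
import numpy as np
import torch

# Assumed layer sizes, chosen only for this illustration.
num_inputs, num_hiddens = 784, 256

# A float32 input batch, like the image tensors fed to net().
X = torch.rand(2, num_inputs)

# Assumption: the weights were built from a NumPy array without a dtype
# argument, so the tensor inherits NumPy's default float64.
W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)))


def dtype_mismatch_raises(a, b):
    """Hypothetical helper: True if torch.matmul(a, b) raises RuntimeError."""
    try:
        torch.matmul(a, b)
    except RuntimeError:
        return True
    return False


print(X.dtype, W1.dtype)                    # dtypes of the two operands
print(dtype_mismatch_raises(X, W1))         # mixed float32 / float64
print(dtype_mismatch_raises(X, W1.float())) # both float32 after the cast
```

Printing `W1.dtype` and `W2.dtype` in the failing notebook would confirm or rule out this explanation.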

LuckyyySTA commented 3 years ago

I ran into the same problem. Have you managed to solve it?

Liang-Liao commented 3 years ago

https://github.com/ShusenTang/Dive-into-DL-PyTorch/issues/156#issuecomment-691947265

Adding a float() conversion to W1 and W2 inside net() fixed it, as follows:

def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1.float()) + b1)
    return torch.matmul(H, W2.float()) + b2
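
Casting inside `net()` works, but it repeats the conversion on every forward pass. A possible alternative, sketched here under the assumption that the parameters are built from NumPy arrays, is to create them as float32 once, so that `net()` can stay as written in Section 3.9.4. The helper `init_weight`, the layer sizes, and the `relu` definition below are illustrative stand-ins, not code taken from the thread.

```python
import numpy as np
import torch

# Assumed sizes for this sketch.
num_inputs, num_hiddens, num_outputs = 784, 256, 10


def init_weight(shape):
    """Hypothetical helper: float32 leaf tensor that tracks gradients."""
    w = torch.tensor(np.random.normal(0, 0.01, shape), dtype=torch.float32)
    return w.requires_grad_(True)


W1 = init_weight((num_inputs, num_hiddens))
b1 = torch.zeros(num_hiddens, requires_grad=True)
W2 = init_weight((num_hiddens, num_outputs))
b2 = torch.zeros(num_outputs, requires_grad=True)


def relu(X):
    # Stand-in ReLU: clamp negative values to zero.
    return torch.clamp(X, min=0)


def net(X):
    # Same structure as the original net(), with no casts needed.
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2


y_hat = net(torch.rand(4, 1, 28, 28))
print(y_hat.shape, y_hat.dtype)
```

With this setup every parameter in `params` is already float32, so the dtypes match the input batches without any per-call conversion.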

LuckyyySTA commented 3 years ago

With your help it now runs successfully. Thank you very much!