ShusenTang / Dive-into-DL-PyTorch

本项目将《动手学深度学习》(Dive into Deep Learning)原书中的MXNet实现改为PyTorch实现。
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0
18.17k stars 5.38k forks source link

3.5 kaggle房价预测代码疑问 #135

Closed HenryYuen128 closed 4 years ago

HenryYuen128 commented 4 years ago

bug描述 请问下列代码中的torch.max()和torch.sqrt()代码为什么不会报错? 我尝试了解max和sqrt函数的作用,单独尝试torch.max(tensorA, torch.tensor(1.0))会报错: Expected object of scalar type Long but got scalar type Float for argument #2 'other'

在PyTorch官方文档中,torch.sqrt()只有一个参数。

def log_rmse(net, features, labels): with torch.no_grad():

将小于1的值设成1, 使得取对数时数值更稳定

    clipped_preds = torch.max(net(features), torch.tensor(1.0))
    rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))
return rmse.item()
ShusenTang commented 4 years ago

“请问下列代码中的torch.max()和torch.sqrt()代码为什么不会报错?” 哪个代码? torch.max(tensorA, torch.tensor(1.0))报的错说的很清楚就是类型不匹配一个是long一个是float。 rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))这里是传入的一个参数

HenryYuen128 commented 4 years ago

“请问下列代码中的torch.max()和torch.sqrt()代码为什么不会报错?” 哪个代码? torch.max(tensorA, torch.tensor(1.0))报的错说的很清楚就是类型不匹配一个是long一个是float。 rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))这里是传入的一个参数

谢谢!