chrispher / chrispher.github.com

Data Science
http://www.datakit.cn

PyTorch basics #6

Open chrispher opened 6 years ago

chrispher commented 6 years ago

http://www.datakit.cn/blog/2017/03/03/pytorch_01_basic.html

Jason-py commented 6 years ago

Why does running the code from Part 3 raise an error?

The simple case:

```python
x = Variable(torch.ones(1), requires_grad=True)
y = x * 2 + 3
y.backward(retain_variables=True)
print(x.grad)  # gradient is dy/dx = 2
```

And a slightly more complex version:

```python
target = torch.FloatTensor([10])
y.backward(target, retain_variables=True)
# gradient is 2 * 10 = 20; because retain was set to True, the
# earlier gradient of 2 accumulates, so the result is 22
print(x.grad)
```

Running this gives: `TypeError: backward() got an unexpected keyword argument 'retain_variables'`
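For reference: the `retain_variables` keyword only existed in early PyTorch releases and was later renamed to `retain_graph`, which is why recent versions raise this `TypeError`. Below is a minimal sketch of the same two steps, assuming a recent PyTorch (>= 0.4, where `Variable` was merged into `Tensor`):

```python
import torch

# In modern PyTorch, a plain tensor with requires_grad=True replaces Variable.
x = torch.ones(1, requires_grad=True)
y = x * 2 + 3

# First backward pass; keep the graph alive so backward() can be called again.
y.backward(retain_graph=True)
print(x.grad)  # tensor([2.])

# Second pass with an explicit gradient argument: contributes 2 * 10 = 20,
# which accumulates into the existing x.grad of 2, giving 22.
target = torch.tensor([10.0])
y.backward(target, retain_graph=True)
print(x.grad)  # tensor([22.])
```

Note that the accumulation into `x.grad` is the default behavior of autograd, independent of the flag; `retain_graph=True` only keeps the intermediate buffers so that `backward()` can be called on the same graph more than once.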