pierreablin opened this issue 2 years ago
The following code leads to an error with torch version 1.10, but not with 1.9:
```python
# Fails on torch 1.10, runs on torch 1.9
import torch
from torch import nn
from momentumnet import MomentumNet

init_function = nn.Tanh()  # activation with no parameters
functions = [nn.Linear(3, 3)]
net = MomentumNet(functions, init_function=init_function, init_speed=True,
                  gamma=0.9, use_backprop=False)
x = torch.randn(1, 3)
output = net(x).sum()
output.backward()
```
Changing `init_function` to a module that has parameters, for example:
```python
# Runs on torch 1.10
import torch
from torch import nn
from momentumnet import MomentumNet

init_function = nn.Linear(3, 3)  # module with a weight and a bias
functions = [nn.Linear(3, 3)]
net = MomentumNet(functions, init_function=init_function, init_speed=True,
                  gamma=0.9, use_backprop=False)
x = torch.randn(1, 3)
output = net(x).sum()
output.backward()
```
works fine.
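The only difference between the two snippets is whether `init_function` owns any trainable parameters. A minimal sketch that makes this difference explicit is shown below. The helper `count_parameters` is hypothetical and not part of momentumnet. This does not identify the cause of the error; it only confirms that the failing configuration has an empty parameter list for `init_function`, while the working one does not.

```python
import torch
from torch import nn


def count_parameters(module: nn.Module) -> int:
    """Hypothetical helper: total number of scalar entries across a module's parameters."""
    return sum(p.numel() for p in module.parameters())


# init_function from the failing snippet: an activation with no parameters.
tanh_params = count_parameters(nn.Tanh())

# init_function from the working snippet: a 3x3 weight plus a 3-element bias.
linear_params = count_parameters(nn.Linear(3, 3))

print(tanh_params, linear_params)  # 0 12
```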
I have no idea what causes this.