michaelsdr / momentumnet

Drop-in replacement for any ResNet with a significantly reduced memory footprint and better representation capabilities
https://michaelsdr.github.io/momentumnet/

Bug in pytorch 1.10 #21

Open pierreablin opened 2 years ago

pierreablin commented 2 years ago

The following code raises an error with torch 1.10 but not with torch 1.9:

import torch
from torch import nn
from momentumnet import MomentumNet

init_function = nn.Tanh()
functions = [nn.Linear(3, 3)]

net = MomentumNet(functions, init_function=init_function, init_speed=True,
                  gamma=0.9, use_backprop=False)

x = torch.randn(1, 3)

output = net(x).sum()
output.backward()

Changing init_function to a module that has parameters, such as nn.Linear:

import torch
from torch import nn
from momentumnet import MomentumNet

init_function = nn.Linear(3, 3)
functions = [nn.Linear(3, 3)]

net = MomentumNet(functions, init_function=init_function, init_speed=True,
                  gamma=0.9, use_backprop=False)

x = torch.randn(1, 3)

output = net(x).sum()
output.backward()

works fine.

I have no idea what causes this.
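For triage, the two cases above can be run side by side in one script that also prints the installed torch version. This is only a sketch: `check_backward` and `main` are hypothetical helper names, not part of momentumnet, and the MomentumNet arguments are copied unchanged from the snippets above.

```python
def check_backward(make_net, make_input):
    """Run one forward and backward pass; return (ok, error_summary)."""
    try:
        net = make_net()
        output = net(make_input()).sum()
        output.backward()
    except Exception as exc:  # report the failure instead of stopping the script
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""


def main():
    # Imported here so check_backward stays usable without torch installed.
    import torch
    from torch import nn
    from momentumnet import MomentumNet

    def build(init_function):
        # Same arguments as in the report above.
        return MomentumNet([nn.Linear(3, 3)], init_function=init_function,
                           init_speed=True, gamma=0.9, use_backprop=False)

    cases = {
        "nn.Tanh (no parameters)": lambda: build(nn.Tanh()),
        "nn.Linear (has parameters)": lambda: build(nn.Linear(3, 3)),
    }
    print("torch", torch.__version__)
    for name, make_net in cases.items():
        ok, err = check_backward(make_net, lambda: torch.randn(1, 3))
        print(f"{name}: {'ok' if ok else 'FAILED ' + err}")
```

Calling `main()` under torch 1.9 and then under 1.10 gives a per-case pass/fail line along with the exception type and message, which should make the report easier to compare across versions.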