michaelsdr / momentumnet

Drop-in replacement for any ResNet with a significantly reduced memory footprint and better representation capabilities
https://michaelsdr.github.io/momentumnet/
MIT License

Error in "drop_in_replacement_tutorial.py" ? #25

Closed jcburnel closed 2 years ago

jcburnel commented 2 years ago

In this example, setting gamma=0.0 leads to mnet1 and net producing two different outputs.

However, everything is fine with resnet18.

(It seems that the residual connection of ResBlock is applied after TransformMemory.)

michaelsdr commented 2 years ago

Hi @jcburnel,

I believe this is due to the dropout. If you execute the following code:

import torch

seq_size = 15
d = 512
n_layers = 12
bs = 10

# A vanilla PyTorch transformer; it is in train mode by default, so dropout is active
transformer = torch.nn.Transformer(num_encoder_layers=n_layers, num_decoder_layers=n_layers)

src = torch.rand((seq_size, bs, d))
tgt = torch.rand((seq_size, bs, d))

# Two forward passes on the same inputs differ because dropout is stochastic
print(transformer(src, tgt) - transformer(src, tgt))

It does not print zeros.

In eval mode, you can check that the momentum transformer with gamma=0 has essentially the same output as the original transformer.
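The eval-mode check can be sketched with a plain PyTorch transformer (small sizes here so it runs quickly; the same comparison applies after converting the model with transform_to_momentumnet and gamma=0):

```python
import torch

# A small transformer; sizes are illustrative, not from the issue
transformer = torch.nn.Transformer(
    d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2
)
src = torch.rand((5, 2, 32))
tgt = torch.rand((5, 2, 32))

# Train mode (the default): dropout makes two forward passes disagree
transformer.train()
diff_train = (transformer(src, tgt) - transformer(src, tgt)).abs().max().item()

# Eval mode: dropout is disabled, so the forward pass is deterministic
transformer.eval()
diff_eval = (transformer(src, tgt) - transformer(src, tgt)).abs().max().item()

print(diff_train, diff_eval)  # diff_eval is exactly 0.0
```

Comparing outputs in eval mode removes the dropout noise, so any remaining difference between the original network and its momentum counterpart reflects the conversion itself.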