jcburnel closed this issue 2 years ago
Hi @jcburnel,
I believe this is due to the dropout layers. If you execute the following code:
```python
import torch

seq_size = 15
d = 512  # matches the default d_model of torch.nn.Transformer
n_layers = 12
bs = 10

transformer = torch.nn.Transformer(num_encoder_layers=n_layers, num_decoder_layers=n_layers)

src = torch.rand((seq_size, bs, d))
tgt = torch.rand((seq_size, bs, d))

# Two forward passes on the same inputs, in the default training mode
print(transformer(src, tgt) - transformer(src, tgt))
```
It does not print zeros, because dropout makes the forward pass stochastic in training mode.
In eval mode, you can check that the momentum transformer with gamma=0 has essentially the same output as the original.
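For instance, here is a minimal sketch (using only torch, with smaller sizes than above) showing that switching a transformer to eval mode disables dropout and makes the forward pass deterministic, so output comparisons become meaningful:

```python
import torch

torch.manual_seed(0)

# A small transformer so the check runs quickly; in eval mode dropout is
# disabled, so two forward passes on identical inputs must match exactly.
transformer = torch.nn.Transformer(
    d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2
)
transformer.eval()

src = torch.rand((5, 2, 32))  # (seq_len, batch, d_model)
tgt = torch.rand((5, 2, 32))

with torch.no_grad():
    out1 = transformer(src, tgt)
    out2 = transformer(src, tgt)

print(torch.allclose(out1, out2))  # True: eval mode removes the randomness
```

The same eval-mode comparison is what you would use to check the original model against its momentum counterpart.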
In this example, setting gamma=0.0 leads to mnet1 and net having two different outputs.
However, everything is fine with resnet18.
(it seems that it applies the residual connection of ResBlock after TransformMemory)