Closed: johannbrehmer closed this issue 4 years ago.
Hi Johann.
Very good catch! Thank you for the detailed explanation. Yes, this makes sense. I would expect a model not to be re-initialized when we continue training after loading from a checkpoint.
It should indeed be registered as a buffer. If you could do a quick PR, that would be great. Otherwise I am happy to push the fix myself.
I also wonder if we could check this behaviour in a test, i.e. save and load the state of `ActNorm`, and make sure `log_scale` and `shift` don't change after passing another batch through.
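A minimal sketch of such a test might look like the following. This assumes the `ActNorm(features)` constructor and the usual `forward(inputs)` call from nflows; the test name, shapes, and tolerances are only illustrative, and it presumes the buffer fix discussed below is in place:

```python
import torch
from nflows.transforms.normalization import ActNorm


def test_actnorm_params_survive_save_load():
    features, batch_size = 5, 64
    actnorm = ActNorm(features)
    actnorm.train()

    # First batch triggers the data-dependent initialization.
    actnorm(torch.randn(batch_size, features))
    log_scale_before = actnorm.log_scale.detach().clone()
    shift_before = actnorm.shift.detach().clone()

    # Round-trip through a state dict into a fresh layer.
    restored = ActNorm(features)
    restored.load_state_dict(actnorm.state_dict())
    restored.train()

    # A further batch should NOT re-initialize the parameters.
    restored(torch.randn(batch_size, features))
    assert torch.allclose(restored.log_scale, log_scale_before)
    assert torch.allclose(restored.shift, shift_before)
```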
Hi,
First of all, I'm a huge fan of this code base. Great work.
I stumbled upon a behaviour of actnorm layers that may not be intended.
Issue
When you save a model with actnorm layers, the flag `ActNorm.initialized` is not part of the state dict (since it's not registered as a buffer). When you then load a model from the state dict and continue training, the scale and shift of the actnorm layers are re-initialized to the mean and standard deviation of the activations.
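For illustration, here is a quick way to see the missing key (assuming the `ActNorm(features)` constructor from nflows; shapes are arbitrary):

```python
import torch
from nflows.transforms.normalization import ActNorm

actnorm = ActNorm(5)
actnorm(torch.randn(64, 5))         # triggers the data-dependent initialization
print(actnorm.state_dict().keys())  # only 'log_scale' and 'shift'; no 'initialized'
```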
Expected behaviour
I would expect saving a state dict, loading it, and then continuing training to behave the same way as training in one go, i.e. without re-initializing the scale and shift parameters. That would also be more consistent with the behaviour of `BatchNorm`.
Fix
In https://github.com/bayesiains/nflows/blob/ff5e9bb715a118581a171f18f5fca1d3d526b4a5/nflows/transforms/normalization.py#L157, replace `self.initialized = False` with `self.register_buffer("initialized", torch.zeros(1, dtype=torch.bool))`, and adapt https://github.com/bayesiains/nflows/blob/ff5e9bb715a118581a171f18f5fca1d3d526b4a5/nflows/transforms/normalization.py#L219 accordingly; a rough sketch of what I have in mind is below.

Does that make sense, or is this behaviour intended? I'm happy to file a PR if you want me to.
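Something along these lines is what I have in mind. This is only a minimal, untested sketch of the pattern: the class body and the data-dependent initialization are schematic stand-ins, not the actual nflows code.

```python
import torch
from torch import nn


class ActNormSketch(nn.Module):
    """Schematic stand-in for nflows' ActNorm, showing only the proposed change."""

    def __init__(self, features):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(features))
        self.shift = nn.Parameter(torch.zeros(features))
        # Buffer instead of a plain Python attribute, so the flag ends up in
        # state_dict() and survives a save/load round trip.
        self.register_buffer("initialized", torch.zeros(1, dtype=torch.bool))

    def _initialize(self, inputs):
        # Data-dependent initialization (schematic).
        with torch.no_grad():
            std = inputs.std(dim=0)
            mu = (inputs / std).mean(dim=0)
            self.log_scale.data = -torch.log(std)
            self.shift.data = -mu
        # In-place update keeps the same registered buffer tensor.
        self.initialized.fill_(True)

    def forward(self, inputs):
        # A one-element bool buffer still works in the existing truthiness check.
        if self.training and not self.initialized:
            self._initialize(inputs)
        return torch.exp(self.log_scale) * inputs + self.shift
```

The only behavioural difference is that `initialized` becomes a tensor, so the check at L219 and the assignment inside the initialization need the small in-place/tensor adjustments sketched above.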
Cheers, Johann