bayesiains / nflows

Normalizing flows in PyTorch

Actnorm does not register `initialized` as buffer #4

Closed johannbrehmer closed 4 years ago

johannbrehmer commented 4 years ago

Hi,

First of all, I'm a huge fan of this code base. Great work.

I stumbled upon a behaviour of actnorm layers that may not be intended.

Issue

When you save a model with actnorm layers, the flag `ActNorm.initialized` is not part of the state dict (since it's not registered as a buffer). When you then load the model from the state dict and continue training, the scale and shift of the actnorm layers are re-initialized to the mean and standard deviation of the activations.

Expected behaviour

I would expect saving a state dict, loading it, and continuing training to behave the same way as training in one go, i.e. without re-initializing the scale and shift parameters. That would also be more consistent with the behaviour of BatchNorm.

Fix

In https://github.com/bayesiains/nflows/blob/ff5e9bb715a118581a171f18f5fca1d3d526b4a5/nflows/transforms/normalization.py#L157, replace `self.initialized = False` with `self.register_buffer("initialized", torch.zeros(1, dtype=torch.bool))`, and adapt https://github.com/bayesiains/nflows/blob/ff5e9bb715a118581a171f18f5fca1d3d526b4a5/nflows/transforms/normalization.py#L219 accordingly.
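
For reference, a minimal sketch of what the buffer-based flag could look like (a simplified stand-in for the actual `ActNorm` class, with the data-dependent initialization condensed; names and details here are illustrative, not the real nflows code):

```python
import torch
from torch import nn


class ActNormSketch(nn.Module):
    """Simplified stand-in for ActNorm, illustrating the proposed buffer."""

    def __init__(self, features):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(features))
        self.shift = nn.Parameter(torch.zeros(features))
        # Registered as a buffer, so the flag is saved/loaded with the state dict.
        self.register_buffer("initialized", torch.zeros(1, dtype=torch.bool))

    def _initialize(self, inputs):
        # Data-dependent initialization from the first batch, as ActNorm does.
        with torch.no_grad():
            std = inputs.std(dim=0)
            mu = (inputs / std).mean(dim=0)
            self.log_scale.data = -torch.log(std)
            self.shift.data = -mu
        # Flip the buffer in place instead of assigning a Python bool.
        self.initialized.fill_(True)

    def forward(self, inputs):
        if self.training and not self.initialized:
            self._initialize(inputs)
        return inputs * torch.exp(self.log_scale) + self.shift
```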

Does that make sense or is this behaviour intended? I'm happy to file a PR if you want me to.

Cheers, Johann

arturbekasov commented 4 years ago

Hi Johann.

Very good catch! Thank you for the detailed explanation. Yes, this makes sense. I would expect a model not to be re-initialized when we continue training after loading from a checkpoint.

It should indeed be registered as a buffer. If you could do a quick PR, that would be great. Otherwise I am happy to push the fix myself.

I also wonder if we could check this behaviour in a test, i.e. save/load the state of `ActNorm` and make sure `log_scale` and `shift` don't change after passing another batch through.
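
Something along these lines, perhaps (a rough sketch; the exact names and assertions would need to match the actual `ActNorm` API):

```python
import torch

from nflows.transforms.normalization import ActNorm


def test_actnorm_state_dict_roundtrip_does_not_reinitialize():
    features = 5
    transform = ActNorm(features)
    transform.train()

    # First batch triggers the data-dependent initialization.
    transform(torch.randn(100, features))
    log_scale_before = transform.log_scale.detach().clone()
    shift_before = transform.shift.detach().clone()

    # Simulate saving a checkpoint and loading it into a fresh instance.
    reloaded = ActNorm(features)
    reloaded.load_state_dict(transform.state_dict())
    reloaded.train()

    # Passing another batch through must not re-initialize the parameters.
    reloaded(torch.randn(100, features))
    assert torch.allclose(reloaded.log_scale, log_scale_before)
    assert torch.allclose(reloaded.shift, shift_before)
```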