probabilists / zuko

Normalizing flows in PyTorch
https://zuko.readthedocs.io
MIT License

Batch normalization #46

Open jmm34 opened 3 months ago

jmm34 commented 3 months ago

I've found the Zuko library to be extremely beneficial for my work. I sincerely appreciate the effort that has gone into its development. In the Masked Autoregressive Flow paper (NeurIPS, 2017), the authors incorporated batch normalization following each autoregressive layer. Could this modification be integrated into the MaskedAutoregressiveTransform function?

francois-rozet commented 3 months ago

Hello @jmm34, thanks for the kind words.

I am not a fan of batch normalization as it often leads to train/test gaps which are hard to diagnose, but I see why one would want to use it (mainly faster training).

IMO the best way to add batch normalization in Zuko would be to implement a standalone (lazy) BatchNormTransform. The user can then insert batch norm transformations anywhere in the flow.

We would accept a PR that implements this.

Edit: I think that using the current batch statistics to normalize is invalid as it would not be an invertible transformation $y = f(x)$ (impossible to know $x$ given $y$). So, we should use running statistics both during training and evaluation, and update these statistics during training. Also, I am not sure that the scale and shift parameters are relevant (mean zero, unit variance is the target).
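To make the point above concrete, here is a minimal sketch in plain PyTorch. It is not Zuko's API: the names `batch_normalize` and `RunningNorm`, and the `momentum` and `eps` defaults, are assumptions chosen for illustration. The first part shows that normalizing with the current batch statistics maps two different batches to the same output, so $x$ cannot be recovered from $y$. The second part normalizes with running statistics in both modes, updates them only during training, and omits the learnable scale and shift.

```python
import torch
import torch.nn as nn


def batch_normalize(x):
    # Normalization with the *current* batch statistics (illustration only).
    return (x - x.mean(dim=0)) / x.std(dim=0, unbiased=False)


# Two different batches that share the same normalized output.
a = torch.tensor([[1.0], [2.0], [3.0]])
b = 10 * a + 5
same_output = torch.allclose(batch_normalize(a), batch_normalize(b))


class RunningNorm(nn.Module):
    """Hypothetical sketch: element-wise normalization with running statistics."""

    def __init__(self, features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Buffers are saved with the module but are not trained by the optimizer.
        self.register_buffer("running_mean", torch.zeros(features))
        self.register_buffer("running_var", torch.ones(features))

    @torch.no_grad()
    def update(self, x):
        # Flatten all batch dimensions, then move the running statistics
        # a fraction `momentum` of the way towards the batch statistics.
        x = x.reshape(-1, x.shape[-1])
        self.running_mean.lerp_(x.mean(dim=0), self.momentum)
        self.running_var.lerp_(x.var(dim=0, unbiased=False), self.momentum)

    def forward(self, x):
        if self.training:
            self.update(x)
        # The same stored statistics are used in training and evaluation.
        scale = torch.sqrt(self.running_var + self.eps)
        y = (x - self.running_mean) / scale
        # log|det dy/dx| of an element-wise affine map, identical for every sample.
        ladj = -torch.log(scale).sum().expand(x.shape[:-1])
        return y, ladj

    def inverse(self, y):
        scale = torch.sqrt(self.running_var + self.eps)
        return y * scale + self.running_mean
```

Because the statistics live in buffers rather than being recomputed from the input, `inverse` recovers `x` from `y` as long as the buffers are unchanged between the two calls. In this sketch the update happens before the normalization, so the forward output always matches the stored buffers. Wrapping this into a lazy Zuko transform is left open here, since it depends on the library's transform interface.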

jmm34 commented 3 months ago

Dear @francois-rozet, thank you very much for your quick reply. I will try to make a PR using the strategy you suggest.