sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

replace BatchNorm with LayerNorm #1035

Open bkmi opened 8 months ago

bkmi commented 8 months ago

Is your feature request related to a problem? Please describe.
BatchNorm breaks the i.i.d. assumption across elements of a batch. That assumption is at the core of all the objective functions we use in this library.
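To illustrate the point, here is a minimal numpy sketch (not sbi code) of why BatchNorm couples batch elements while LayerNorm does not: under BatchNorm, a sample's output changes when its batch-mates change.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Normalize each feature using statistics computed ACROSS the batch,
    # so every output depends on the other elements in the batch.
    return (x - x.mean(axis=0, keepdims=True)) / np.sqrt(x.var(axis=0, keepdims=True) + eps)

def layer_norm(x, eps=1e-5):
    # Normalize each sample using statistics computed WITHIN that sample,
    # so outputs are independent of the rest of the batch.
    return (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))              # a batch of 4 samples, 3 features
x_new_mates = x.copy()
x_new_mates[1:] = rng.normal(size=(3, 3))  # resample every sample except the first

# Sample 0's BatchNorm output changes when its batch-mates change...
print(np.allclose(batch_norm(x)[0], batch_norm(x_new_mates)[0]))  # False
# ...but its LayerNorm output does not.
print(np.allclose(layer_norm(x)[0], layer_norm(x_new_mates)[0]))  # True
```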

Describe the solution you'd like
We should switch to LayerNorm or GroupNorm wherever possible.

Describe alternatives you've considered
Give people a choice, but this is a niche issue and, to be honest, I see no reason to use BatchNorm except for legacy concerns.

Additional context
Classifiers typically expose the option in this format: use_batch_norm: bool = False. We could simply change that to use_layer_norm.
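The proposed flag change might look like the following hypothetical builder (build_mlp, dim, and hidden are illustrative names, not the actual sbi API):

```python
import torch
import torch.nn as nn

def build_mlp(dim: int, hidden: int = 50, use_layer_norm: bool = False) -> nn.Sequential:
    # Hypothetical classifier builder mirroring the pattern described above;
    # use_layer_norm replaces the old use_batch_norm flag.
    layers = [nn.Linear(dim, hidden)]
    if use_layer_norm:
        # LayerNorm normalizes over the feature dimension of each sample,
        # so it never mixes information across the batch.
        layers.append(nn.LayerNorm(hidden))
    layers += [nn.ReLU(), nn.Linear(hidden, 1)]
    return nn.Sequential(*layers)

net = build_mlp(dim=3, use_layer_norm=True)
out = net(torch.randn(4, 3))  # shape (4, 1)
```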

bkmi commented 7 months ago

@janfb wanted to see a side-by-side comparison of LayerNorm and BatchNorm before making the change.

francois-rozet commented 7 months ago

To comment on that issue: with BatchNorm, one must be very careful to always feed batches whose "modes" appear in the same proportions as in the original distribution. For example, for a binary classifier, it is invalid to evaluate the positive and negative batches separately during training, because identifying the class of a single element in the batch then suffices to decide the class of the entire batch.
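A small numpy sketch (illustrative, not sbi code) of the failure mode described above: normalizing each class batch separately erases the per-sample separation, while the raw batch mean alone already reveals the class of the whole batch.

```python
import numpy as np

rng = np.random.default_rng(1)
neg = rng.normal(0.0, 1.0, size=(256, 1))  # class 0
pos = rng.normal(5.0, 1.0, size=(256, 1))  # class 1, clearly separated in raw space

def batch_norm(x, eps=1e-5):
    # Per-batch normalization, as BatchNorm does in training mode.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# After per-class normalization, both batches have mean ~0 and std ~1,
# so a downstream per-sample rule can no longer separate them...
print(batch_norm(neg).mean(), batch_norm(pos).mean())  # both ~0
# ...while the batch-level statistic leaks the class of the entire batch.
print(neg.mean() > 2.5, pos.mean() > 2.5)  # False True
```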