google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Standardizing normalization layers #3664

Open chiamp opened 5 months ago

chiamp commented 5 months ago


LayerNorm is commonly understood as normalizing the activations by reducing across all non-batch axes. In Flax's implementation of LayerNorm, the default is reduction_axes=-1. This works for 2D inputs, but for higher-dimensional tensors it would only reduce the trailing dimension. Should we change the default implementation so that it normalizes all non-batch axes by default (assuming the leading dimension is the batch axis)? This also applies to RMSNorm.
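
A minimal sketch of the current behavior (the 4D shape here is purely illustrative, assuming the standard linen API):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative 4D input: (batch, height, width, features).
x = jnp.ones((2, 8, 8, 16))

# Current default: only the trailing (feature) axis is reduced.
default_ln = nn.LayerNorm()  # reduction_axes=-1
vars_default = default_ln.init(jax.random.PRNGKey(0), x)
y_default = default_ln.apply(vars_default, x)

# Normalizing over all non-batch axes requires passing them explicitly.
full_ln = nn.LayerNorm(reduction_axes=(1, 2, 3))
vars_full = full_ln.init(jax.random.PRNGKey(0), x)
y_full = full_ln.apply(vars_full, x)
```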

Another thing: currently all normalization layers with a learnable scale and bias have a feature_axis (or equivalent) input arg so that the user can specify the shape of the learnable params, except GroupNorm (which always uses the trailing axis). Should we add this to GroupNorm as well?
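
For reference, a quick sketch of the current situation (assuming the current linen argument names, i.e. feature_axes on LayerNorm/RMSNorm and axis on BatchNorm):

```python
import flax.linen as nn

# LayerNorm / RMSNorm expose feature_axes and BatchNorm exposes axis, so the
# caller controls which axes the learnable scale/bias params are shaped over.
ln = nn.LayerNorm(feature_axes=-1)
rms = nn.RMSNorm(feature_axes=-1)
bn = nn.BatchNorm(axis=-1, use_running_average=True)

# GroupNorm has no such argument; its params always follow the trailing axis.
gn = nn.GroupNorm(num_groups=4)
```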

cgarciae commented 5 months ago

This works for 2D inputs, but for higher-dimensional tensors it would only reduce the trailing dimension.

Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

Should we add this into GroupNorm as well?

Yeah, sounds like a good idea! Thanks for looking into this.

chiamp commented 5 months ago

Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

After internal discussion, we have decided to keep the default reduction_axes=-1.