chiamp opened 5 months ago
`LayerNorm` is understood as normalizing the activations by reducing across all non-batch axes. In Flax's current implementation of `LayerNorm`, the default is `reduction_axes=-1`. This works for 2D inputs, but for higher-dimensional tensors, this would only reduce the trailing dimension. Should we change the default implementation so that it normalizes all non-batch axes by default (assuming the leading dimension is the batch axis)? This also applies to `RMSNorm`.

Another thing is that currently all normalization layers with learnable scale and bias have a `feature_axis` (or equivalent) input arg so that the user can specify the shape of the learnable params, except `GroupNorm` (which always defines `feature_axis=-1`). Should we add this into `GroupNorm` as well?

Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

Yeah, sounds like a good idea! Thanks for looking into this.

After internal discussion, we have decided to keep the default `reduction_axes=-1`.
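A minimal sketch of the difference being discussed, using plain NumPy rather than Flax's actual implementation (the `layer_norm` helper and the `(batch, height, features)` shape here are hypothetical, for illustration only): with a 3D input, reducing only the trailing axis normalizes each feature vector independently, whereas reducing all non-batch axes normalizes each example as a whole.

```python
import numpy as np

def layer_norm(x, axes, eps=1e-6):
    # Normalize x over the given axes: zero mean, unit variance per slice.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# Hypothetical 3D input: (batch, height, features).
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# Current default behavior: reduce only the trailing dimension.
y_last = layer_norm(x, axes=(-1,))

# Proposed default: reduce all non-batch axes (leading dim is the batch axis).
y_all = layer_norm(x, axes=(1, 2))

# Each variant has zero mean over its own reduction axes, but the
# normalized outputs disagree for inputs with more than 2 dimensions.
print(np.allclose(y_last, y_all))  # → False
```

For a 2D input of shape `(batch, features)`, the two choices coincide, which is why the `reduction_axes=-1` default only becomes visible on higher-dimensional tensors.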