Open rwightman opened 2 years ago
To be more specific, GroupNorm w/ groups=1 normalizes over C, H, W. LayerNorm as used in transformers normalizes over the channel dimension only. Since PyTorch LN doesn't natively support channels-only normalization of 2D (rank-4 NCHW) tensors, a 'LayerNorm2d' impl (ConvNeXt, EdgeNeXt, CoAtNet, and many more) is often used that either manually calculates mean/var over the C dim or permutes to NHWC and back. In either case the norm remains over just the channel dim.
GroupNorm(C, groups=1, affine=False) == LayerNorm([C, H, W], elementwise_affine=False) NOT LayerNorm(C) w/ permute.
Additionally, if the affine scale/bias is enabled, there is no way to get equivalence, as GroupNorm scales/shifts per channel (over the C dim), while LayerNorm applies its scale/shift elementwise over all of C, H, W in the case where LN == GN(groups=1).
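A quick way to check the claims above is to compare the three norms on a random tensor. This is only a minimal sketch: the tensor sizes and the helper `ln_channels_only` are made up for illustration, and it relies on the standard `torch.nn.GroupNorm` and `torch.nn.LayerNorm` modules with their default `eps`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 8, 5, 7  # arbitrary sizes, chosen only for illustration
x = torch.randn(N, C, H, W)

# GroupNorm with a single group: statistics over (C, H, W) per sample.
gn = nn.GroupNorm(num_groups=1, num_channels=C, affine=False)

# LayerNorm over all non-batch dims: same statistics as gn.
ln_chw = nn.LayerNorm([C, H, W], elementwise_affine=False)

# Transformer-style LayerNorm: statistics over C only.
ln_c = nn.LayerNorm(C, elementwise_affine=False)


def ln_channels_only(t):
    """Hypothetical helper: apply ln_c to an NCHW tensor via an NHWC permute."""
    return ln_c(t.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


same_as_ln_chw = torch.allclose(gn(x), ln_chw(x), atol=1e-5)
same_as_ln_c = torch.allclose(gn(x), ln_channels_only(x), atol=1e-5)
print("GN(1) == LN([C, H, W]):", same_as_ln_chw)
print("GN(1) == LN(C) w/ permute:", same_as_ln_c)

# With affine enabled, the learnable parameters have different shapes,
# so the two modules cannot be made equivalent in general.
gn_affine_shape = tuple(nn.GroupNorm(1, C).weight.shape)        # one scale per channel
ln_affine_shape = tuple(nn.LayerNorm([C, H, W]).weight.shape)   # one scale per element
print(gn_affine_shape, ln_affine_shape)
```

The shape comparison at the end is the affine point in concrete form: GroupNorm learns C scale/shift values, while LayerNorm over [C, H, W] learns C*H*W of them.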
We use the relationship between GroupNorm and LayerNorm as described in the GroupNorm paper. This is also consistent with PyTorch's documentation, which suggests that putting all channels in one group is equivalent to layer norm. We will clarify it in the documentation.
Thanks for the suggestions. We will re-train MobileViTv2 with ConvNeXt-style LayerNorm and also rename LayerNorm2D to GroupNorm (to be consistent with other works and implementations).
@sacmehta the equivalence of GN and LN as per the paper is for NCHW tensors when LN is performed over all of C, H, W (minus the affine part as mentioned). However, the LN in transformers, including when used with 2D NCHW tensors, is usually over just C.
There is nothing at all wrong with what you've implemented, it may well be better, but calling it LN is a bit confusing given other uses and the difference in how affine params are applied. PoolFormer uses the same norm as you, but theirs is just called GroupNorm (w/ groups forced to 1); I called it GroupNorm1 when I used it (not sure that makes it any more clear though, heh).
There would be a few fewer flops in the LN over C only case, but unfortunately with no efficient impl for PyTorch, the permutes required can slow things down a bit. In either case I'd be curious to see if the accuracy changes.
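For reference, the two common ways of writing a channels-only 'LayerNorm2d' for NCHW input look roughly like the sketch below. The class names `LayerNorm2dPermute` and `LayerNorm2dManual` are hypothetical and the `eps` default is an assumption; this is not the code of any particular repo, just an illustration of the permute-based and manual mean/var variants.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2dPermute(nn.LayerNorm):
    """Channels-only LN for NCHW input via NCHW -> NHWC -> NCHW permutes."""

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC so C is the last dim
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)  # back to NCHW


class LayerNorm2dManual(nn.Module):
    """Channels-only LN for NCHW input with mean/var computed over dim=1."""

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=1, keepdim=True)  # biased variance
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Per-channel scale/shift, broadcast over H and W.
        return x * self.weight[:, None, None] + self.bias[:, None, None]


torch.manual_seed(0)
x = torch.randn(2, 8, 5, 7)
out_permute = LayerNorm2dPermute(8)(x)
out_manual = LayerNorm2dManual(8)(x)

variants_agree = torch.allclose(out_permute, out_manual, atol=1e-5)
# Each spatial location is normalized independently across channels.
max_abs_pixel_mean = out_permute.mean(dim=1).abs().max().item()
print(variants_agree, max_abs_pixel_mean)
```

Both variants keep the statistics per spatial location, which is what distinguishes them from GroupNorm w/ groups=1.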
@rwightman Thanks for the feedback. I will change the name of files for clarity.
I will keep you posted about LN experiments.
@rwightman I ran experiments replacing LayerNorm2d (or GroupNorm1) with LayerNorm (transformer or ConvNeXt style) and found that it makes training unstable and also hurts performance.
Attached are the training and validation loss curves (I did some hyper-parameter tweaking for the new experiments, but the training instability was fairly consistent across experiments). I used the LayerNorm implementation from ane-transformers.
[Figure: training loss comparison]
[Figure: validation loss comparison]
@sacmehta thanks for the update, looks like the channels-only LN is definitely not stable in this architecture.
I tried to implement MobileViT v2 in TensorFlow 2.x, using layers.LayerNormalization(epsilon=1e-6) for the layer norm. Comparing the output of each layer, I found that it is inconsistent with the output of the LayerNorm in the PyTorch version.
https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization
Later, I switched to tfa.layers.GroupNormalization (TensorFlow Addons) with groups=1 and checked the output of each layer; it is consistent with the LayerNorm of the PyTorch version.
https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
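The mismatch is what you'd expect from the axes each layer reduces over. Below is a NumPy-only sketch, not actual TensorFlow code: it assumes NHWC input, that LayerNormalization with its default axis=-1 takes statistics over C only, and that group norm with groups=1 takes them over H, W, C. The `normalize` helper is hypothetical and the affine scale/shift is left out.

```python
import numpy as np


def normalize(x, axes, eps=1e-6):
    """Hypothetical helper: zero-mean / unit-variance over the given axes."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # biased variance
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((2, 5, 7, 8))  # N, H, W, C (arbitrary sizes)

# Statistics over the last axis only (channels), one set per spatial location.
ln_last_axis = normalize(x_nhwc, axes=(-1,))

# Statistics over H, W, C together, one set per sample.
gn_one_group = normalize(x_nhwc, axes=(1, 2, 3))

outputs_match = np.allclose(ln_last_axis, gn_one_group)
print("channels-only LN == GN(groups=1):", outputs_match)
```

So a channels-only layer norm and a single-group group norm generally give different activations on the same input, which would show up as per-layer output differences when porting weights.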
I checked the transformer code implemented by the Keras team; they are using layers.LayerNormalization(epsilon=1e-6).
https://keras.io/examples/vision/image_classification_with_vision_transformer/
Using my TensorFlow 2.x reimplementation of MobileViT v2, the per-layer outputs are consistent with the PyTorch version. However, during transfer learning the loss does not decrease (batch size 64).
Re your MobileViTv2: it would be misleading to call the norm LayerNorm2d, as group norm w/ groups=1 is not equivalent to it. 'LayerNorm2d' is already used elsewhere in other nets. Might be worth retraining MobileViTv2 with an actual LayerNorm or renaming the norm to just GroupNorm.
https://github.com/apple/ml-cvnets/blob/84d992f413e52c0468f86d23196efd9dad885e6f/cvnets/layers/normalization/layer_norm.py#L56