sail-sg / poolformer

PoolFormer: MetaFormer Is Actually What You Need for Vision (CVPR 2022 Oral)
https://arxiv.org/abs/2111.11418
Apache License 2.0

About Normalization #9

Closed chenhang98 closed 2 years ago

chenhang98 commented 2 years ago

Hi, thanks for your excellent work. In your ablation studies (section 4.4), you compared Group Normalization (group number is set as 1 for simplicity), Layer Normalization, and Batch Normalization. The conclusion is that Group Normalization is 0.7% or 0.8% higher than Layer Normalization or Batch Normalization. But when the number of groups is 1, Group Normalization is equivalent to Layer Normalization, right?

yuweihao commented 2 years ago

Hi @tinyalpha ,

In transformers, the conventional data format is [B, W*H, C], and transformers usually use Layer Normalization that applies only to the channel dimension. If the data format is [B, C, W, H], as in this repo, this means normalizing only C, not [W, H]. Currently there is no PyTorch LayerNorm API for the [B, C, W, H] data format that normalizes only the channel dimension, so I implemented it here and named it LayerNormChannel. The Layer Normalization in the paper refers to this conventional Layer Normalization of transformers.

For data format [B, C, W, H], Group Normalization normalizes dimensions C, W, and H at the same time.

In summary (run in poolformer root dir):

import torch
import torch.nn as nn
from models.poolformer import LayerNormChannel # only normalize channel
from models.poolformer import GroupNorm # group == 1

C, W, H = 64, 56, 56 # data format [B, C, W, H]

layer_norm_channel = LayerNormChannel(C) # only normalize C, conventionally used by transformers and in this paper ablation study
print(layer_norm_channel.weight.shape) # learnable affine weight shape: torch.Size([64])

group_norm = GroupNorm(C) # normalize C, H, and W.
print(group_norm.weight.shape) # learnable affine weight shape: torch.Size([64])

layer_norm = nn.LayerNorm([C, H, W]) # normalize C, H, and W.
print(layer_norm.weight.shape) # learnable affine weight shape: torch.Size([64, 56, 56])

You can clearly see their differences from the above code. Therefore, "when the number of groups is 1, Group Normalization is equivalent to Layer Normalization" is NOT right. There is also a small mistake in the PyTorch GroupNorm API documentation: the statement "Put all 6 channels into a single group (equivalent with LayerNorm)" in that documentation is NOT right either.
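As a quick numerical check (a minimal sketch with PyTorch defaults, eps=1e-5, and affine disabled so that only the normalization statistics are compared; the learnable affine shapes still differ as noted above), GroupNorm with a single group produces the same output as nn.LayerNorm over (C, H, W), but not the same output as normalizing over the channel dimension alone:

import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 64, 56, 56
x = torch.randn(B, C, H, W)

group_norm = nn.GroupNorm(num_groups=1, num_channels=C, affine=False)  # statistics over (C, H, W) per sample
layer_norm_chw = nn.LayerNorm([C, H, W], elementwise_affine=False)     # statistics over (C, H, W) per sample

# channel-only normalization (what transformers' LayerNorm does for this layout)
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, keepdim=True, unbiased=False)
channel_only = (x - mean) / torch.sqrt(var + 1e-5)

print(torch.allclose(group_norm(x), layer_norm_chw(x), atol=1e-5))  # True: same dims are normalized
print(torch.allclose(group_norm(x), channel_only, atol=1e-5))       # False: different dims are normalized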

chenhang98 commented 2 years ago

Thanks for your explanation!

yuweihao commented 2 years ago

You are welcome~

ma-xu commented 2 years ago

This really helps. Thanks a lot.

iumyx2612 commented 1 year ago

So does it mean that applying nn.LayerNorm on (B, C, H, W) is equal to GroupNorm with groups = 1?

And the LayerNorm that decreases the performance in the ablation study is LayerNormChannel, not nn.LayerNorm, right?

yuweihao commented 1 year ago

Hi @iumyx2612 , thanks for your attention.

  1. Applying nn.LayerNorm on (B, C, H, W) is NOT equal to GroupNorm with groups = 1. In the example given above, print(group_norm.weight.shape) gives torch.Size([64]), while print(layer_norm.weight.shape) gives torch.Size([64, 56, 56]). The difference is the shape of their affine weight and bias.
  2. LayerNormChannel for input (B, C, H, W) is equal to nn.LayerNorm for input (B, N, C). Because this repo uses the data format (B, C, H, W), the LayerNorm in this repo is implemented as LayerNormChannel, which corresponds to the results of the ablation study (see the sanity check after the tables).

In summary, to identify whether two normalizations are equal, first check which dims are used to compute the mean and variance, then check whether their affine weight and bias have the same shape.

For (B, C, H, W) input:

| Norm | Dims used to compute mean and variance | Shape of affine weight and bias |
| --- | --- | --- |
| LayerNormChannel(num_channels=C) | (C, ) | (C, ) |
| nn.GroupNorm(num_groups=1, num_channels=C) | (C, H, W) | (C, ) |
| nn.LayerNorm(normalized_shape=(C, H, W)) | (C, H, W) | (C, H, W) |

For (B, N, C) input (like timm's ViT implementation):

| Norm | Dims used to compute mean and variance | Shape of affine weight and bias |
| --- | --- | --- |
| nn.LayerNorm(normalized_shape=C) | (C, ) | (C, ) |
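For point 2, here is a quick sanity check (a sketch, assuming both modules keep their default eps and affine initialization; run in the poolformer root dir):

import torch
import torch.nn as nn
from models.poolformer import LayerNormChannel  # run in poolformer root dir

torch.manual_seed(0)
B, C, H, W = 2, 64, 56, 56
x = torch.randn(B, C, H, W)

layer_norm_channel = LayerNormChannel(C)  # input (B, C, H, W), normalizes C only
vit_layer_norm = nn.LayerNorm(C)          # input (B, N, C), as in ViT

out_channel = layer_norm_channel(x)
# reshape (B, C, H, W) -> (B, N, C) with N = H*W, apply ViT-style LayerNorm, reshape back
out_vit = vit_layer_norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).reshape(B, C, H, W)

print(torch.allclose(out_channel, out_vit, atol=1e-4))  # expected True: they compute the same thing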

You may also refer to another example where I implement a general LayerNorm that handles different situations for the data format (B, H, W, C).
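As a side note (my own sketch, not the repo's code): for channels-last data (B, H, W, C), plain nn.LayerNorm(C) already normalizes only the channel dimension, because normalized_shape only has to match the trailing dimension.

import torch
import torch.nn as nn

B, H, W, C = 2, 56, 56, 64
x_channels_last = torch.randn(B, H, W, C)  # channels-last layout
layer_norm = nn.LayerNorm(C)               # normalizes only the last dim (C)
print(layer_norm(x_channels_last).shape)   # torch.Size([2, 56, 56, 64])
print(layer_norm.weight.shape)             # learnable affine weight shape: torch.Size([64])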

iumyx2612 commented 1 year ago

Thanks for the detailed explanation! I can see that nn.LayerNorm and nn.GroupNorm(num_groups=1) compute the mean and variance over the same dimensions; however, the weights of the two Norm layers are different. Since both do the same normalization in this situation (the dims used to compute mean and var are the same), does using nn.LayerNorm perform worse than nn.GroupNorm, or better? After all, the LayerNorm in the ablation study is LayerNormChannel (which only computes mean and var over C), not nn.LayerNorm (which computes mean and var over the same dims as nn.GroupNorm).

yuweihao commented 1 year ago

Hi @iumyx2612 , yes, the difference between nn.LayerNorm(normalized_shape=(C, H, W)) and nn.GroupNorm(num_groups=1, num_channels=C) is the shape of their affine weight and bias, and the difference between nn.GroupNorm(num_groups=1, num_channels=C) and ViT's LayerNorm is the dims used to compute mean and variance. For your question, I did not conduct experiments with nn.LayerNorm(normalized_shape=(C, H, W)) because it means the model can only accept one input resolution. So I don't have experimental results on hand for the comparison between nn.LayerNorm(normalized_shape=(C, H, W)) and nn.GroupNorm(num_groups=1, num_channels=C), and sorry that I cannot give a clear conclusion on it.
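To illustrate the fixed-resolution constraint (a minimal sketch, not from the repo): nn.LayerNorm(normalized_shape=(C, H, W)) bakes H and W into the module, so a different input resolution raises a shape-mismatch error, while nn.GroupNorm(num_groups=1) works for any resolution.

import torch
import torch.nn as nn

C = 64
layer_norm_chw = nn.LayerNorm([C, 56, 56])  # affine weight shape: (64, 56, 56)
group_norm = nn.GroupNorm(num_groups=1, num_channels=C)

x_56 = torch.randn(2, C, 56, 56)
x_28 = torch.randn(2, C, 28, 28)

print(layer_norm_chw(x_56).shape)  # torch.Size([2, 64, 56, 56])
print(group_norm(x_28).shape)      # works for any resolution: torch.Size([2, 64, 28, 28])

try:
    layer_norm_chw(x_28)           # trailing dims don't match (64, 56, 56)
except RuntimeError:
    print("nn.LayerNorm([C, 56, 56]) fails for a different resolution")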

iumyx2612 commented 1 year ago

Very clear explanation, thank you for taking your time to answer me! Much appreciated

yuweihao commented 1 year ago

Welcome~