microsoft / FocalNet

[NeurIPS 2022] Official code for "Focal Modulation Networks"
MIT License

Use of nn.LayerNorm by FocalNet for a segmentation task and its alternatives #13

Closed shahzad-ali closed 1 year ago

shahzad-ali commented 1 year ago

Excellent work!

By default, layer normalization is used as FocalNet(norm_layer=nn.LayerNorm). I'm wondering if it's a better choice for a semantic segmentation task. I would love to hear some thoughts on this.

Strangely enough, setting norm_layer=nn.BatchNorm2d caused several errors: the x in nn.BatchNorm2d(embed_dim)(x) is a 3D tensor with embed_dim as its last dimension, while nn.BatchNorm2d expects a 4D (B, C, H, W) input.

  1. If nn.LayerNorm is supposed to be the default normalization for FocalNet, then why do we even have it as an input parameter?
  2. If one wishes to use a different normalization, is there a quick fix?

Looking forward to getting awesome replies. Thanks!
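The shape mismatch described above can be reproduced in a few lines. This is a minimal sketch, not FocalNet code; the sizes are illustrative (96 is the default embed_dim of the tiny variants, 3136 = 56 × 56 tokens):

```python
import torch
import torch.nn as nn

# Channels-last token tensor, as FocalNet passes to its norm layers: (B, L, C)
x = torch.randn(2, 3136, 96)

# LayerNorm normalizes over the last dimension, so a 3D input is fine.
y = nn.LayerNorm(96)(x)
assert y.shape == (2, 3136, 96)

# BatchNorm2d expects a 4D (B, C, H, W) input and rejects the 3D tensor.
try:
    nn.BatchNorm2d(96)(x)
except ValueError as e:
    print(type(e).__name__)  # the dimension check fails on 3D input
```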

jwyang commented 1 year ago

Hi, @shahzad-ali , directly replacing LayerNorm with BatchNorm2d is incompatible because the dimension arrangements are different: FocalNet's norm layers see channels-last (B, L, C) tensors, while BatchNorm2d expects (B, C, H, W). You would need to permute the input tensors accordingly. A simpler way is to use BatchNorm1d. This way you only need to transpose the input in the last two dimensions before each norm and transpose back right afterward.
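The transpose-before/transpose-after recipe can be packaged as a small wrapper module so it can be passed as norm_layer. This is a sketch, not part of the FocalNet repo; the class name is made up here:

```python
import torch
import torch.nn as nn

class BatchNorm1dChannelsLast(nn.Module):
    """Hypothetical drop-in for nn.LayerNorm on (B, L, C) tensors:
    transposes to (B, C, L) for BatchNorm1d, then transposes back."""

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # x: (B, L, C) -> (B, C, L) -> BatchNorm1d -> (B, L, C)
        return self.bn(x.transpose(1, 2)).transpose(1, 2)

# Same call pattern as nn.LayerNorm(embed_dim) on a 3D token tensor.
embed_dim = 96
norm = BatchNorm1dChannelsLast(embed_dim)
x = torch.randn(2, 56 * 56, embed_dim)
y = norm(x)
assert y.shape == x.shape
```

Something like this could then be passed as FocalNet(norm_layer=BatchNorm1dChannelsLast), since it takes the dimension as its only constructor argument, mirroring nn.LayerNorm.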

jwyang commented 1 year ago

One more comment: our FocalNets with LayerNorm have already been demonstrated to work well for semantic segmentation on ADE20K. You can try the provided checkpoints directly without any changes.