facebookresearch / ConvNeXt

Code release for ConvNeXt model
MIT License

Question 1x1 conv vs linear #18

KatsarosEf closed this issue 2 years ago

KatsarosEf commented 2 years ago

Congratulations on your work and thanks for sharing! A naive question: what is the reason for implementing the 1x1 convs as fully connected (linear) layers? I know the two are equivalent, but I had thought the latter was less efficient.

Thanks in advance!

liuzhuang13 commented 2 years ago

Update: I retested the comparison, and the info I gave before was not accurate (I remembered it wrong, so I deleted it to avoid misleading anyone). Sorry about the confusion. My latest observations on V100 GPU inference throughput are below:

  1. Testing standalone MLPs, (NHWC -> linear layers) is ~5% faster than (NCHW -> 1x1 convs). This was tested with C=256, feature resolution 14, batch size 128.
  2. Tested in ConvNeXts with LN disabled in the blocks, (NCHW -> 1x1 convs -> layer scale) is 0-5% faster (depending on resolution and model size) than (NCHW -> permute to NHWC -> linear layers -> layer scale -> permute back to NCHW). This reversal could be partly due to the permutation, or to other whole-model properties. Interestingly, if I use the "channels_last" memory format in PyTorch, the latter is sometimes faster at 384 resolution.
  3. Tested in ConvNeXts with LN included in the blocks, (NCHW -> custom LN -> 1x1 convs -> layer scale) is now 0-5% slower than (NCHW -> permute to NHWC -> PyTorch LN -> linear layers -> layer scale -> permute back to NCHW). The custom LN is one we wrote that operates on NCHW tensors (PyTorch's LN only supports tensors with C as the last dimension).
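For anyone who wants to reproduce something like observation 1, here is a rough timing sketch. It is not the script used for the numbers above. The helper names `throughput` and `compare`, the two-layer MLP with GELU, and the 4x hidden ratio are all assumptions made for illustration. It falls back to CPU when no GPU is available, so absolute numbers will not match a V100.

```python
import time

import torch
import torch.nn as nn


def throughput(module, x, iters=10, warmup=3):
    """Return samples/second for repeated forward passes of `module` on `x`."""
    use_cuda = x.is_cuda
    with torch.no_grad():
        for _ in range(warmup):
            module(x)
        if use_cuda:
            torch.cuda.synchronize()  # make sure warmup kernels have finished
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if use_cuda:
            torch.cuda.synchronize()  # wait for all timed kernels
    elapsed = time.perf_counter() - start
    return iters * x.shape[0] / elapsed


def compare(batch=128, C=256, res=14, hidden_ratio=4, device=None):
    """Time a standalone MLP as linear layers (NHWC) vs 1x1 convs (NCHW)."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    hidden = hidden_ratio * C  # assumed expansion ratio, not stated in the thread

    mlp_linear = nn.Sequential(
        nn.Linear(C, hidden), nn.GELU(), nn.Linear(hidden, C)
    ).to(device).eval()
    mlp_conv = nn.Sequential(
        nn.Conv2d(C, hidden, kernel_size=1), nn.GELU(), nn.Conv2d(hidden, C, kernel_size=1)
    ).to(device).eval()

    x_nhwc = torch.randn(batch, res, res, C, device=device)  # channels last
    x_nchw = torch.randn(batch, C, res, res, device=device)  # channels first

    return {
        "linear_nhwc": throughput(mlp_linear, x_nhwc),
        "conv1x1_nchw": throughput(mlp_conv, x_nchw),
    }
```

Calling `compare()` with the defaults uses the C=256, resolution 14, batch size 128 setting from observation 1. Results will vary with hardware, PyTorch version, and cuDNN settings.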

Looking at 2 and 3, the ultimate reason why (NCHW -> permute to NHWC -> PyTorch LN -> linear layers -> layer scale -> permute back to NCHW) is slightly faster than (NCHW -> custom LN -> 1x1 convs -> layer scale) seems to be that our custom LN layer operating on NCHW tensors is much slower than PyTorch's LN, which only supports normalizing over the last dimension(s), i.e. NHWC tensors here.
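To make the two LN variants concrete, here is a minimal sketch of a layer norm over the channel dimension of an NCHW tensor, checked against PyTorch's built-in layer norm applied after permuting to NHWC. The function name `layer_norm_nchw` is hypothetical and this is not necessarily how the repo's custom LN is written. One plausible reason a version like this is slower (not confirmed in this thread) is that it is composed of several separate reduction and elementwise ops rather than a single fused kernel.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def layer_norm_nchw(x, weight, bias, eps=1e-6):
    """Layer norm over the channel dimension (dim=1) of an NCHW tensor."""
    mean = x.mean(dim=1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=1, keepdim=True)  # biased variance, as in LayerNorm
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Broadcast the per-channel affine parameters over H and W.
    return weight[:, None, None] * x_hat + bias[:, None, None]


N, C, H, W = 2, 256, 14, 14
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Variant A: normalize directly on the NCHW tensor.
y_custom = layer_norm_nchw(x, weight, bias)

# Variant B: permute to NHWC, use the built-in LN over the last dim, permute back.
x_nhwc = x.permute(0, 2, 3, 1)
y_builtin = F.layer_norm(x_nhwc, (C,), weight, bias, eps=1e-6).permute(0, 3, 1, 2)

ln_match = torch.allclose(y_custom, y_builtin, rtol=1e-4, atol=1e-4)
```

Both variants compute the same values up to floating-point error, so the choice between them is about speed only.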

So we need the permutation to NHWC anyway in order to use PyTorch's LN, and given observation 1 (without permutation, linear layers are faster than 1x1 convs), we use linear layers for the "MLP" part before permuting back to NCHW.
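As a sanity check on the equivalence this whole discussion relies on, the sketch below copies the weights of a 1x1 conv into a linear layer and compares the outputs on the same input. The 4x output width is an assumption for illustration, and only C=256 and resolution 14 come from the thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# C and resolution follow observation 1; a small batch keeps the check quick.
N, C, H, W = 2, 256, 14, 14
x_nchw = torch.randn(N, C, H, W)

conv = nn.Conv2d(C, 4 * C, kernel_size=1)  # pointwise conv, expects NCHW
linear = nn.Linear(C, 4 * C)               # linear layer, acts on the last dim (NHWC)

with torch.no_grad():
    # Conv weight has shape (out, in, 1, 1); linear weight has shape (out, in).
    linear.weight.copy_(conv.weight.view(4 * C, C))
    linear.bias.copy_(conv.bias)

    y_conv = conv(x_nchw)                        # (N, 4C, H, W)
    y_lin = linear(x_nchw.permute(0, 2, 3, 1))   # (N, H, W, 4C)
    y_lin = y_lin.permute(0, 3, 1, 2)            # back to (N, 4C, H, W)

same = torch.allclose(y_conv, y_lin, rtol=1e-4, atol=1e-4)
```

The two paths apply the same per-pixel matrix multiply, so any speed difference comes from memory layout and kernel selection rather than from the math.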

KatsarosEf commented 2 years ago

Thank you for your swift and very detailed response on the architectural design. Indeed, I had only noticed your comment on line 30 :). One last question, if you don't mind, and then I will close this issue: was GELU chosen over ReLU because of observed dying neurons, or solely based on the related Transformer papers (BERT, GPT-2), as mentioned? Did you experiment with alternatives such as Swish?

Many thanks again, and all the best with your future work.

liuzhuang13 commented 2 years ago

The choice of GELU over ReLU is in part due to imitating Transformers. Another interesting observation is that if we stick to ReLU, in the next step ("Fewer activations") the training curve becomes a bit strange, although it eventually converges to a reasonable level. We didn't try activations other than ReLU and GELU.