jeya-maria-jose / UNeXt-pytorch

Official Pytorch Code base for "UNeXt: MLP-based Rapid Medical Image Segmentation Network", MICCAI 2022
https://jeya-maria-jose.github.io/UNext-web/
MIT License

Not as fast as the paper says #2

Closed JOP-Lee closed 2 years ago

JOP-Lee commented 2 years ago

Thank you for sharing this work. In my experiments I found the speed to be quite slow. May I ask what might be causing this? The OverlapPatchEmbed and shiftMLP modules both use convolutions in the MLP stage, with kernel sizes of 3 and 7, which slows things down. Why is that?

jeya-maria-jose commented 2 years ago

Hi, can you please provide more details? The FLOPs of the model will match the numbers reported in the paper. If you have the same hardware as reported in the paper, you will be able to match the inference speed as well. What do you mean by "quite slow"?
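
For reference, inference latency can be benchmarked like this (a minimal sketch: the UNext class comes from this repo's archs.py, and the batch and input sizes are illustrative, adjust them to your setup):

import time
import torch
from archs import UNext  # model definition in this repo's archs.py

model = UNext(num_classes=1).cuda().eval()
x = torch.randn(1, 3, 256, 256, device="cuda")  # illustrative input size
with torch.no_grad():
    for _ in range(10):  # warm-up so CUDA kernels are initialized
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()  # wait for all kernels before reading the clock
print(f"avg latency: {(time.time() - start) / 100 * 1000:.2f} ms")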

The answer to why there are DWConv layers in the MLP stage is explained in the paper; check out page 5. Thanks.

JOP-Lee commented 2 years ago

@jeya-maria-jose Hi, thanks for the reply. I used the MLP stage from the paper in place of U-Net blocks for another task, but the training speed did not improve. Does this method not speed up training? I agree with this sentence in the paper: "the embedding dimension H which is significantly less than the dimensionality of the feature maps (H/N)x(H/N) where N is a factor of 2 depending on the block". However, the OverlapPatchEmbed and shiftMLP modules both use convolutions in the MLP stage:

x = self.dwconv(x, H, W)  # depth-wise conv inside the shiftMLP block
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))  # projection conv in OverlapPatchEmbed

A Conv2d is also used at the bottleneck of the original U-Net, so I do not understand how this increases speed. Is it by reducing the parameters through in_chans and embed_dim?
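
For context, the block I am referring to looks roughly like this (a simplified sketch of shiftMLP with the axial shift operations and dropout omitted; names are illustrative, not the exact implementation):

import torch
import torch.nn as nn

class DWConv(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # depth-wise 3x3: groups=dim gives one filter per channel
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)  # tokens -> feature map
        x = self.dwconv(x)
        return x.flatten(2).transpose(1, 2)     # feature map -> tokens

class TokenizedMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)  # channel-mixing MLP
        self.dwconv = DWConv(hidden_dim)       # cheap local mixing
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(self.dwconv(x, H, W))
        return self.fc2(x)

# usage: 16x16 feature map tokenized into 256 tokens of width 160
mlp = TokenizedMLP(dim=160, hidden_dim=160)
out = mlp(torch.randn(2, 16 * 16, 160), 16, 16)  # -> (2, 256, 160)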

jeya-maria-jose commented 2 years ago

Hi @JOP-Lee, thanks for the clear explanation. Note that in the original UNet architecture there are two conv layers in each block. Replacing one conv layer with the MLP block is what brings the increase in speed: if we count the overlap patch embedding as one conv layer, the other one is replaced by the MLP block. I agree that the MLP block still contains a DWConv, but it needs far fewer computations than the original conv. Overall, we show that a combination of conv and MLP blocks achieves better performance with faster inference compared to the original configuration.
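
To make the comparison concrete, here is a rough multiply-add count for a single layer (a sketch with illustrative sizes, not measurements from the paper):

def conv_flops(c_in, c_out, h, w, k):
    # standard KxK conv: every output pixel mixes all input channels
    return c_in * c_out * h * w * k * k

def dwconv_flops(c, h, w, k):
    # depth-wise KxK conv: each channel is filtered independently
    return c * h * w * k * k

C, H, W = 128, 32, 32  # illustrative channel count and feature-map size
print(f"standard 3x3 conv: {conv_flops(C, C, H, W, 3):,}")  # ~151.0M
print(f"depth-wise 3x3:    {dwconv_flops(C, H, W, 3):,}")   # ~1.2M
print(f"1x1 (MLP) mixing:  {conv_flops(C, C, H, W, 1):,}")  # ~16.8M

Even counting the two pointwise (Linear) layers of the MLP block alongside the DWConv, the replacement stays well below the cost of one full 3x3 convolution at the same width.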