ELEKTRONN / elektronn3

A PyTorch-based library for working with 3D and 2D convolutional neural networks, with focus on semantic segmentation of volumetric biomedical image data
MIT License

UNet Implementation details #52

Open SebastienTs opened 1 year ago

SebastienTs commented 1 year ago

I noticed some subtle differences between the implementation of your UNet and the one in this repo: https://github.com/johschmidt42/PyTorch-2D-3D-UNet-Tutorial.

Basically, in that repo batch normalization is applied after the ReLU layers rather than directly after the convolutional layers. Also, the weights are initialized with the Xavier uniform method instead of being set to a constant.

For my application, this leads to better accuracy and speeds up training by quite some margin. Have you tested both approaches, and do you have a strong opinion about it?

Related to this, I'm wondering if you would consider including some newer UNet variants, such as pre-trained encoders and hybrid networks with transformer blocks.

If not, is there a repository that you would particularly recommend for importing PyTorch models for bioimage analysis applications (3D fluorescence LM and 3D EM)?

mdraw commented 1 year ago

I chose this order of BN / ReLU layers because it was used in the other implementations I looked at (including the paper that originally proposed batch norm), and because at the time I knew about optimized fused conv+bn+relu ops but had not seen conv+relu+bn ops. But if you think about it, I guess BN after ReLU makes more sense, because ReLU after BN kind of defeats the purpose of normalization: the inputs are zero mean, unit variance, and ReLU just sets all negative elements to zero (i.e. discards half of the input elements on average), which shifts the mean away from zero and cuts the variance to roughly a third (about 0.34 for a standard normal input). That doesn't really make sense, so if you don't have to stay compatible with earlier neural network architectures, I would indeed suggest switching the order. I also just found a good Reddit discussion about this: https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
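
To make the two orderings concrete, here is a minimal sketch, assuming plain 3x3x3 convolutions. The helper names `conv_bn_relu` and `conv_relu_bn` are made up for illustration and are not the actual elektronn3 or tutorial code. It also checks what ReLU alone does to zero-mean, unit-variance values (analytically, the mean becomes about 0.40 and the variance about 0.34).

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    """Conv -> BatchNorm -> ReLU: normalization happens before the activation."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(),
    )


def conv_relu_bn(in_ch: int, out_ch: int) -> nn.Sequential:
    """Conv -> ReLU -> BatchNorm: the block output itself is normalized."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm3d(out_ch),
    )


torch.manual_seed(0)

# Effect of ReLU alone on zero-mean, unit-variance values.
z = torch.randn(1_000_000)
relu_z = torch.relu(z)
print(f"after ReLU: mean={relu_z.mean().item():.3f}, var={relu_z.var().item():.3f}")

# Per-channel statistics of the two block outputs. Freshly built modules are
# in training mode, so BatchNorm normalizes with the statistics of this batch.
x = torch.randn(4, 1, 16, 16, 16)  # (batch, channels, depth, height, width)
dims = (0, 2, 3, 4)  # every dimension except the channel dimension
with torch.no_grad():
    out_a = conv_bn_relu(1, 8)(x)
    out_b = conv_relu_bn(1, 8)(x)

print("conv -> BN -> ReLU, per-channel mean:", out_a.mean(dim=dims))
print("conv -> ReLU -> BN, per-channel mean:", out_b.mean(dim=dims))
print("conv -> ReLU -> BN, per-channel var: ", out_b.var(dim=dims))
```

With BN last, each output channel is close to zero mean and unit variance in training mode. With ReLU last, the output is non-negative and its statistics depend on the learned BN scale and shift. Which one trains better is an empirical question for the task at hand.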

Regarding initialization, I would just do experiments with different initializations and use what works best for you. The optimal initialization can vary depending on the architecture.
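
For such experiments, one way to swap the initialization is `Module.apply` with a small init function. This is only a sketch: `xavier_uniform_init` and the toy network are hypothetical stand-ins, and zeroing the biases is an assumption made for the example, not something taken from either repo.

```python
import math

import torch
import torch.nn as nn


def xavier_uniform_init(module: nn.Module) -> None:
    """Xavier-uniform weights and zero biases for conv layers (illustrative helper)."""
    conv_types = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    if isinstance(module, conv_types):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


torch.manual_seed(0)

# Toy stand-in for a real model; any nn.Module works the same way.
net = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 2, kernel_size=1),
)
net.apply(xavier_uniform_init)  # visits every submodule recursively

# Xavier uniform samples from U(-bound, bound) with
# bound = sqrt(6 / (fan_in + fan_out)) for the default gain of 1.
first_conv = net[0]
fan_in = 1 * 3 * 3 * 3   # in_channels * kernel volume
fan_out = 8 * 3 * 3 * 3  # out_channels * kernel volume
bound = math.sqrt(6.0 / (fan_in + fan_out))
print(f"max |w| = {first_conv.weight.abs().max().item():.4f}, bound = {bound:.4f}")
```

Replacing `nn.init.xavier_uniform_` with another function from `torch.nn.init` (for example `kaiming_normal_`) is enough to compare schemes on the same architecture.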

I haven't done extensive tests of initialization and batch norm placement myself. I can only say that enabling batch norm consistently makes training converge much faster than disabling it, but that's not that surprising I guess :)

I am not planning to include more recent neural network architectures because, in my experience, almost none of the architectures published since the U-Net / 3D U-Net papers actually work better than a well-adapted U-Net, despite their own reported benchmark results.

If you are looking for well-written fancy new network architecture implementations, I suggest the collection in https://github.com/Project-MONAI/MONAI/tree/dev/monai/networks/nets. You can just install monai and import the network models from the monai package and use them in the same way as the elektronn3-included models without any further changes because they are also just torch.nn.Module subclasses.

SebastienTs commented 1 year ago

Thanks for elaborating on this topic!

Since this is the only other difference I could find between the previously mentioned UNet (with inverted ReLU / BN) and the elektronn3 UNet, can you confirm that the dummy attention layers play absolutely no role (as the name suggests)?

It's possibly an extreme case, but for my application, which consists of filling up tubes from a membrane staining, the UNet with inverted BN gives a satisfactory solution, while I can't manage to obtain similar results with the elektronn3 UNet, even with longer training and deeper/wider networks.

Regarding other UNet variants, I'm also under the impression that most do not bring much improvement compared to a well-configured and well-trained vanilla UNet, at least for bioimage analysis, where whole objects often have limited extent and tend to be uniformly textured.

This blog attempts to schematically summarize the differences between the main UNet variants: https://link.medium.com/qc9T5MI5nCb

The exact application case is unfortunately not described and the treatment is rather superficial, but the author also suggests that the improvements are marginal, with only Attention U-Net and U-Net3+ mentioned as getting a more substantial edge.

I believe that your implementation of the attention UNet is similar to the architecture described in the blog. Could you confirm this?

I couldn't find any parametric PyTorch implementation of a 3D UNet3+. Do you know of any?