Closed NRSummerfield closed 2 years ago
Hi @NRSummerfield
Thanks for your detailed comment. The implementation is correct. We count the layers in a Pythonic way (i.e. from zero), hence the 4th layer (if counted from 1) corresponds to index 3.
Best
Thank you @ahatamiz
According to the paper, shouldn't it start by pulling from the 3rd layer (counted from 1)? Using the 12 transformers as it is, the 4 branches do not have 3 transformers each but instead 4, 3, 3, and 2.
In this diagram, the assumption is that layers start from zero (e.g. z_0), hence z_3 corresponds to hidden_states_out[3]. The others follow the same logic.
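To make the zero-based indexing concrete, here is a minimal sketch (illustrative only, not the actual MONAI ViT code) of a 12-block transformer that records each block's output in a list, so that index i holds the output of block i:

```python
# Hypothetical sketch of zero-based hidden-state collection.
# Strings stand in for tensors; "block_i(...)" marks one transformer block.
blocks = [f"block_{i}" for i in range(12)]

hidden_states_out = []
x = "input"
for i, blk in enumerate(blocks):
    x = f"{blk}({x})"            # apply transformer block i
    hidden_states_out.append(x)  # hidden_states_out[i] = output of block i

# Zero-based index 3 is the output of the 4th block (counted from 1):
z3 = hidden_states_out[3]
```

Under this convention, hidden_states_out[3] is produced after four transformer blocks have been applied, which is the source of the counting question in this thread.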
I recognize you wrote the paper, so thank you for your help. Just for clarity, I am now confused about the branch z_12.
In the implementation, z_12 corresponds to the output of self.vit, which goes through transformer z_11 followed by a normalization layer. There are only 12 transformer layers, so there are only 2 transformers between z_9 and z_12. The architecture therefore follows a transformer layout of 4, 3, 3, 2 instead of what I'd assume would be a 3, 3, 3, 3 configuration of transformers between the branches. Is this correct?
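The 4, 3, 3, 2 count can be checked with a few lines of arithmetic (names here are illustrative, not from the MONAI source), assuming features are tapped after zero-based blocks 3, 6, 9, and at the final output after block 11:

```python
# Zero-based indices where features are extracted; the last tap is the
# final ViT output (block 11 plus normalization).
tap_points = [3, 6, 9, 11]

prev = -1  # "before the first block"
blocks_per_branch = []
for t in tap_points:
    blocks_per_branch.append(t - prev)  # blocks applied since the last tap
    prev = t
# blocks_per_branch == [4, 3, 3, 2]
```

So with taps at 3, 6, 9, and the final output, the branches indeed span 4, 3, 3, and 2 transformer blocks respectively.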
Thank you for your help.
Hi @NRSummerfield
I appreciate your comment. Yes, as you correctly noted, the architecture follows a 4, 3, 3, 2 pattern. Since UNETR uses a monolithic architecture, choosing which layers to extract features from is not a straightforward task and requires more analysis. In our initial experiments, the 4, 3, 3, 2 layout resulted in better segmentation performance.
The MONAI implementation of UNETR closely follows the implementation used to produce the results presented in the paper, allowing the reported benchmarks to be reproduced.
Thanks
Okay, I understand. Thank you very much for your help!
I believe there is an error with the implementation of the UNETR network within monai.networks.nets.unetr.py
Issue
From the forward function: the transformers are called via indices 3, 6, and 9, resulting in the network pulling from transformers 4, 7, 10, and 12 (counted from 1).
Expected behavior
Meanwhile, the paper (https://arxiv.org/abs/2103.10504) states: "we extract a sequence representation zi (i ∈ {3,6,9,12}) ... from the transformer".
In the forward function, the transformer indices should be 2, 5, and 8.
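The expectation behind the proposed indices can be sketched with simple arithmetic (variable names are illustrative, not from the MONAI source): with zero-based taps at 2, 5, 8 and the final output after block 11, each branch would span exactly three transformer blocks.

```python
# Zero-based tap indices proposed in the issue, plus the final output.
tap_points = [2, 5, 8, 11]

# Blocks in the first branch, then blocks between consecutive taps.
spans = [tap_points[0] + 1] + [b - a for a, b in zip(tap_points, tap_points[1:])]
# spans == [3, 3, 3, 3]
```

This is the uniform 3, 3, 3, 3 layout the issue assumes the paper describes, in contrast to the 4, 3, 3, 2 layout of the released implementation.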