Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Indexing error in UNETR #4520

Closed NRSummerfield closed 2 years ago

NRSummerfield commented 2 years ago

I believe there is an error with the implementation of the UNETR network within monai.networks.nets.unetr.py

Issue

From the forward function:

    def forward(self, x_in):
        x, hidden_states_out = self.vit(x_in)
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(x)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)

The hidden states are indexed at 3, 6, and 9, which (together with the final ViT output) results in the network pulling from transformers 4, 7, 10, and 12.
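
As a quick sanity check of which blocks those indices correspond to, the hidden states can be inspected directly. This is a minimal sketch, assuming a MONAI version whose ViT.forward returns the per-block hidden states as in the code above; the input size and parameters are illustrative only.

    import torch
    from monai.networks.nets import UNETR

    # Illustrative configuration; the specific sizes are not taken from this issue.
    model = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96), feature_size=16)

    with torch.no_grad():
        x, hidden_states_out = model.vit(torch.randn(1, 1, 96, 96, 96))

    print(len(hidden_states_out))       # 12 entries, one per transformer block
    print(hidden_states_out[3].shape)   # output of the 4th block (0-based index 3)
    # The encoders therefore tap blocks 4, 7 and 10, while x has passed through all 12 blocks.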

Expected behavior

Meanwhile, the paper (https://arxiv.org/abs/2103.10504) states: "we extract a sequence representation z_i (i ∈ {3, 6, 9, 12}) ... from the transformer".

In the forward function, the transformer indices should be 2, 5, and 8.

ahatamiz commented 2 years ago

Hi @NRSummerfield

Thanks for your detailed comment. The implementation is correct: we count the layers in a Pythonic way (i.e. from zero), so index 3 corresponds to the 4th layer when counted from 1.
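
For illustration, here is a simplified toy sketch of how the per-block hidden states are collected (a stand-in, not the actual monai.networks.nets.vit code): each entry is appended after a block has run, so hidden_states_out[3] is the output of the 4th block when counted from 1.

    import torch
    import torch.nn as nn

    class TinyViT(nn.Module):
        """Toy stand-in mimicking how per-block hidden states are collected."""

        def __init__(self, num_layers: int = 12, hidden_size: int = 768):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x):
            hidden_states_out = []
            for blk in self.blocks:
                x = blk(x)
                hidden_states_out.append(x)  # entry i is the output of block i (0-based)
            return self.norm(x), hidden_states_out

    x, hidden_states_out = TinyViT()(torch.randn(1, 216, 768))
    # hidden_states_out[3] has passed through 4 blocks; x has passed through all 12 plus the norm.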

Best

NRSummerfield commented 2 years ago

Thank you @ahatamiz

According to the paper, shouldn't it start by pulling from the 3rd layer (counted from 1)? With the 12 transformers indexed as they are, the four branches do not have 3 transformers each, but instead 4, 3, 3, and 2.

(attached image: the UNETR architecture diagram from the paper, showing the z_i extraction points)

ahatamiz commented 2 years ago

In this diagram, the assumption is that the layers are numbered from zero (e.g. z_0), hence z_3 corresponds to hidden_states_out[3]. The other indices follow the same logic.

NRSummerfield commented 2 years ago

I recognize you wrote the paper, so thank you for your help. Just for clarity, I am now confused about the branch z_12. In the implementation, z_12 corresponds to the output of self.vit, which goes through transformer z_11 followed by a normalization layer. Since there are only 12 transformer layers, there are only 2 transformers between z_9 and z_12. The architecture therefore follows a transformer layout of 4, 3, 3, 2 instead of what I would have assumed to be a 3, 3, 3, 3 configuration between the branches. Is this correct?

Thank you for your help.

ahatamiz commented 2 years ago

Hi @NRSummerfield

I appreciate your comment. Yes, as you correctly mentioned, the architecture follows a 4, 3, 3, 2 pattern. Since UNETR uses a monolithic architecture, choosing which layers to extract features from is not a straightforward task and requires more analysis. In our initial experiments, the 4, 3, 3, 2 layout resulted in better segmentation performance.
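
Counting the blocks between the extraction points makes this explicit; the short script below is just illustrative arithmetic over the mapping discussed above (0-based indices 3, 6, 9 plus the final output, i.e. 1-based blocks 4, 7, 10, and 12).

    # 1-based block numbers that feed the decoder branches.
    extraction_points = [4, 7, 10, 12]

    previous = 0
    for point in extraction_points:
        print(f"blocks {previous + 1}-{point}: {point - previous} transformer layers")
        previous = point
    # blocks 1-4: 4, blocks 5-7: 3, blocks 8-10: 3, blocks 11-12: 2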

The MONAI implementation of UNETR closely follows the implementation used to produce the results presented in the paper, which allows the reported benchmarks to be reproduced.

Thanks

NRSummerfield commented 2 years ago

Okay, I understand. Thank you very much for your help!