Rubics-Xuan / TransBTS

This repo provides the official code for: 1) TransBTS: Multimodal Brain Tumor Segmentation Using Transformer (https://arxiv.org/abs/2103.04430), accepted by MICCAI 2021. 2) TransBTSV2: Towards Better and More Efficient Volumetric Segmentation of Medical Images (https://arxiv.org/abs/2201.12785).
Apache License 2.0

Some questions about the code #4

Closed: chengjianhong closed this issue 3 years ago

chengjianhong commented 3 years ago

Hi, great work! But there are a few small problems that puzzle me. 1) The image shape changes from 240×240×155 to 240×240×160; what are the considerations here? 2) I can't understand the intmd_encoder_outputs of the encoder, which is the output intmd_x of the transformer. What is the difference between x and intmd_x? Besides, encoder_output seems not to be used in the Decoder.

x, intmd_x = self.transformer(x)

 def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):

        x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x)

        decoder_output = self.decode(
            x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers
        )

        if auxillary_output_layers is not None:
            auxillary_outputs = {}
            for i in auxillary_output_layers:
                val = str(2 * i - 1)
                _key = 'Z' + str(i)
                auxillary_outputs[_key] = intmd_encoder_outputs[val]

            return decoder_output

        return decoder_output
Rubics-Xuan commented 3 years ago

Thanks for your questions. Firstly, the padding operation is not strictly needed in the code: since Random_Crop() exists, we don't have to worry about 155 not being divisible by OS = 8.
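Just to illustrate the divisibility point, here is a minimal, standalone sketch (not the repo's actual preprocessing) of padding the depth axis from 155 to 160 so that it is divisible by the overall stride OS = 8; as noted above, Random_Crop() makes this step optional in practice.

    import numpy as np

    # Minimal sketch (not the repo's preprocessing): pad the depth axis so it
    # becomes divisible by the overall stride OS = 8, e.g. 155 -> 160 slices.
    def pad_depth_to_multiple(volume, multiple=8):
        depth = volume.shape[-1]
        pad = (-depth) % multiple  # 155 needs 5 extra slices to reach 160
        return np.pad(volume, ((0, 0), (0, 0), (0, pad)), mode="constant")

    vol = np.zeros((240, 240, 155), dtype=np.float32)
    print(pad_depth_to_multiple(vol).shape)  # (240, 240, 160)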

Secondly, x is the output of the Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))) block of the last transformer layer (i.e. the 4th one), while intmd_encoder_outputs collects the outputs of both sub-blocks of every layer, i.e.

1. Residual(PreNormDrop(dim, dropout_rate, SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate)))
2. Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))

for all 4 transformer layers (1st, 2nd, 3rd, 4th). So the difference is easy to tell (see the sketch below).
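Here is a minimal sketch of that pattern, with illustrative names rather than the repo's exact IntermediateSequential implementation: every registered sub-block (attention, then feed-forward) writes its output into a dict, so each of the 4 transformer layers contributes two entries, and the feed-forward output of layer i ends up under key str(2 * i - 1).

    import torch.nn as nn

    # Illustrative sketch only -- not the repo's exact class. Each registered
    # sub-block (attention, feed-forward, attention, ...) stores its output in
    # a dict keyed by its running index, so layer i's feed-forward output ends
    # up under key str(2 * i - 1); the final output is also returned as x.
    class SequentialWithIntermediate(nn.Module):
        def __init__(self, *blocks):
            super().__init__()
            self.blocks = nn.ModuleList(blocks)

        def forward(self, x):
            intmd = {}
            for idx, block in enumerate(self.blocks):
                x = block(x)
                intmd[str(idx)] = x  # keys '0'..'7' for 4 layers x 2 sub-blocks
            return x, intmd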

From this part of the code you can tell that encoder_output is indeed passed into the Decoder:

    decoder_output = self.decode(
        x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers
    )

chengjianhong commented 3 years ago

The encoder_output is passed in as the x argument of the decoder. However, x is not actually used inside the decoder. Is x the same as encoder_outputs[all_keys[0]]?

def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):

        assert intmd_layers is not None, "pass the intermediate layers for MLA"
        encoder_outputs = {}
        all_keys = []
        for i in intmd_layers:
            val = str(2 * i - 1)
            _key = 'Z' + str(i)
            all_keys.append(_key)
            encoder_outputs[_key] = intmd_x[val]
        all_keys.reverse()

        x8 = encoder_outputs[all_keys[0]]
        x8 = self._reshape_output(x8)
        x8 = self.Enblock8_1(x8)
        x8 = self.Enblock8_2(x8)

        y4 = self.DeUp4(x8, x3_1)  # (1, 64, 32, 32, 32)
        y4 = self.DeBlock4(y4)

        y3 = self.DeUp3(y4, x2_1)  # (1, 32, 64, 64, 64)
        y3 = self.DeBlock3(y3)

        y2 = self.DeUp2(y3, x1_1)  # (1, 16, 128, 128, 128)
        y2 = self.DeBlock2(y2)

        y = self.endconv(y2)      # (1, 4, 128, 128, 128)
        y = self.Softmax(y)
        return y
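To see which entry x8 is taken from, here is a standalone trace of just the key bookkeeping in the quoted decode() (a sketch for illustration, not repo code): after all_keys.reverse(), all_keys[0] is 'Z4', which was filled from intmd_x['7'], i.e. the entry stored for the last transformer layer.

    # Standalone trace of the key bookkeeping in decode() above (not repo code).
    intmd_layers = [1, 2, 3, 4]

    all_keys = []
    source_indices = []
    for i in intmd_layers:
        source_indices.append(str(2 * i - 1))  # '1', '3', '5', '7' -> keys into intmd_x
        all_keys.append('Z' + str(i))          # 'Z1', 'Z2', 'Z3', 'Z4'
    all_keys.reverse()

    print(all_keys[0], source_indices[-1])     # prints: Z4 7 -> x8 = intmd_x['7']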