KVishnuVardhanR / M3T

three-dimensional medical image classification using Multi-plane and Multi-slice Transformer
GNU General Public License v3.0

There are some problems in the "Extraction of Multi-plane, Multi-slice images and 2D CNN block in M3T" part #6

Closed: sleepontheafternoon closed this issue 5 months ago

sleepontheafternoon commented 5 months ago

Thanks for reproducing the code of this paper. Based on your work and my understanding of the paper, I found that you did not use the pretrained ResNet-50 model, only a plain average pooling layer. This is why your parameter count is much smaller than the one reported in the paper. Building on your work, I rewrote this part of the code as follows:

import torch
import torch.nn as nn
from torchvision import models


class MultiPlane_MultiSlice_Extract_Project(nn.Module):
    '''
    The multi-plane and multi-slice image features extraction from the 3D
    representation features X and applying 2D CNN followed by Non-Linear
    Projection
    N = length = width = height based on the mentioned input size in the paper
    Ref: 3.3. Extraction of Multi-plane, Multi slice images and
         3.4. 2D Convolutional Neural Network Block
    '''

    def __init__(self, out_channels: int):
        super().__init__()
        # 2D CNN part: ImageNet-pretrained ResNet-50, adapted to the model's needs.
        # `weights` expects a ResNet50_Weights enum (or None), not a bool.
        self.CNN_2D = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # The stem must accept the 3D CNN's `out_channels` feature maps instead of
        # 3 RGB channels. This new conv1 is randomly initialised, so the
        # pretrained stem weights are not reused.
        self.CNN_2D.conv1 = nn.Conv2d(out_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the 1000-class head so the backbone returns 2048-d pooled features.
        self.CNN_2D.fc = nn.Identity()

        # Non-linear projection block: 2048 -> 512 -> 256
        self.non_linear_proj = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, input_tensor):
        # Assumes D == H == W (N in the docstring); otherwise torch.cat below fails.
        B, C, D, H, W = input_tensor.shape

        # input tensor shape:   B 32 128 128 128
        coronal = input_tensor.permute(0, 2, 1, 3, 4)   # B D C H W -> B 128 32 128 128
        sagittal = input_tensor.permute(0, 3, 1, 2, 4)  # B H C D W -> B 128 32 128 128
        axial = input_tensor.permute(0, 4, 1, 2, 3)     # B W C D H -> B 128 32 128 128

        # torch.cat returns a new contiguous tensor, so .view() is safe afterwards.
        S = torch.cat([coronal, sagittal, axial], dim=1)  # B 3*128 32 128 128
        S = S.view(-1, C, H, W)  # B*3D C H W -> B*3*128 32 128 128

        pooled_feat = self.CNN_2D(S)                       # B*3D 2048
        output_tensor = self.non_linear_proj(pooled_feat)  # B*3D 256
        return output_tensor.view(B, 3 * D, -1)            # B 3D 256
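Because the permute/cat/view chain in `forward` is easy to get wrong, it can help to trace the shapes without running the model. This is a minimal pure-Python sketch (no PyTorch) that mirrors only the shape effect of each step, assuming `B = 2` and the `B 32 128 128 128` input from the code comments; `permute_shape` and `trace_forward_shapes` are hypothetical helpers, not part of the repository.

```python
def permute_shape(shape, dims):
    # Mirrors the shape effect of tensor.permute(*dims): new axis i is old axis dims[i].
    return tuple(shape[d] for d in dims)


def trace_forward_shapes(B, C, D, H, W, feat_dim=2048, proj_dim=256):
    x = (B, C, D, H, W)
    coronal = permute_shape(x, (0, 2, 1, 3, 4))   # (B, D, C, H, W)
    sagittal = permute_shape(x, (0, 3, 1, 2, 4))  # (B, H, C, D, W)
    axial = permute_shape(x, (0, 4, 1, 2, 3))     # (B, W, C, D, H)

    # torch.cat(..., dim=1) requires all other dimensions to agree.
    if not (coronal[2:] == sagittal[2:] == axial[2:]):
        raise ValueError("torch.cat would fail: the module needs D == H == W")
    stacked = (B, coronal[1] + sagittal[1] + axial[1]) + coronal[2:]

    # view(-1, C, H, W) folds the batch and slice axes into one 2D batch.
    slices = (stacked[0] * stacked[1], C, H, W)
    cnn_out = (slices[0], feat_dim)    # ResNet-50 with fc = Identity: 2048 features
    proj_out = (slices[0], proj_dim)   # Linear(2048, 512) -> ReLU -> Linear(512, 256)
    final = (B, 3 * D, proj_dim)       # view(B, 3 * D, -1)
    return {"stacked": stacked, "slices": slices, "cnn_out": cnn_out,
            "proj_out": proj_out, "final": final}


for name, shape in trace_forward_shapes(B=2, C=32, D=128, H=128, W=128).items():
    print(f"{name:>8}: {shape}")
```

With these sizes the 2D backbone sees 768 slices of shape 32 x 128 x 128 per forward pass, which is worth keeping in mind when choosing a batch size.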

After revising it, the total number of parameters is about 30.28M.
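The gap described in this issue comes almost entirely from the ResNet-50 backbone: average pooling has no learnable parameters, while a ResNet-50 trunk has over 23M. As a rough cross-check, here is a minimal pure-Python sketch that counts parameters by hand instead of calling PyTorch. It assumes torchvision's standard ResNet-50 layout (bias-free convolutions, BatchNorm with a learnable weight and bias per channel, a 1x1 projection shortcut in the first block of each stage) and `out_channels = 32`, as the shape comments in the code suggest. The helper names are hypothetical, and the count covers only this module; the rest of the 30.28M would sit in the other M3T components, which are not shown in this thread.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    # Bias-free k x k convolution, as used throughout torchvision's ResNet.
    return c_in * c_out * k * k


def bn_params(c: int) -> int:
    # BatchNorm2d: learnable weight and bias per channel (running stats are buffers).
    return 2 * c


def linear_params(n_in: int, n_out: int) -> int:
    # nn.Linear: weight matrix plus bias vector.
    return n_in * n_out + n_out


def bottleneck_params(c_in: int, width: int, downsample: bool) -> int:
    # 1x1 reduce -> 3x3 -> 1x1 expand (expansion factor 4), each followed by BatchNorm.
    c_out = 4 * width
    total = conv_params(c_in, width, 1) + bn_params(width)
    total += conv_params(width, width, 3) + bn_params(width)
    total += conv_params(width, c_out, 1) + bn_params(c_out)
    if downsample:
        # Projection shortcut (1x1 conv + BatchNorm) when channels or stride change.
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet50_params(in_channels: int = 3, include_fc: bool = True) -> int:
    # Stem: 7x7 conv (in_channels -> 64) + BatchNorm; max pooling has no parameters.
    total = conv_params(in_channels, 64, 7) + bn_params(64)
    c_in = 64
    for width, blocks in [(64, 3), (128, 4), (256, 6), (512, 3)]:
        for i in range(blocks):
            total += bottleneck_params(c_in, width, downsample=(i == 0))
            c_in = 4 * width
    # Global average pooling is parameter-free; only the 1000-class fc head adds more.
    if include_fc:
        total += linear_params(c_in, 1000)
    return total


def extract_project_params(out_channels: int = 32) -> dict:
    # Backbone with conv1 rebuilt for `out_channels` inputs and fc = nn.Identity().
    backbone = resnet50_params(in_channels=out_channels, include_fc=False)
    projection = linear_params(2048, 512) + linear_params(512, 256)
    return {"backbone": backbone, "projection": projection,
            "total": backbone + projection}


print(resnet50_params())         # 25557032 (stock ImageNet ResNet-50)
print(extract_project_params())  # {'backbone': 23598976, 'projection': 1180416, 'total': 24779392}
```

Under these assumptions the rewritten module accounts for roughly 24.8M parameters. On the actual model, the same figure can be read with `sum(p.numel() for p in module.parameters())`.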

KVishnuVardhanR commented 5 months ago

Hey,

Thank you so much for raising the issue. I went through the paper once again, made the appropriate changes to replicate it, and was successful.

The total number of parameters is now about 29.12M.

Thank you,
Vishnu