jeya-maria-jose / Medical-Transformer

Official Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation" - MICCAI 2021
MIT License

Question about 'medtnet' #50

Closed JackHeroFly closed 2 years ago

JackHeroFly commented 2 years ago

Hello author. Thanks for your code. Do you only compute attention within each 32x32 patch, so that attention between patches is not considered? Look at the following code:

    for i in range(0, 4):
        for j in range(0, 4):
            x_p = xin[:,:,32*i:32*(i+1),32*j:32*(j+1)] # 4x1x32x32

            # begin patch wise

            x_p = self.conv1_p(x_p)
            x_p = self.bn1_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p) #4x64x16x16

            x_p = self.conv2_p(x_p) #4x128x16x16
            x_p = self.bn2_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p)
            x_p = self.conv3_p(x_p)#4x64x16x16
            x_p = self.bn3_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p)

            # x = self.maxpool(x)
            # pdb.set_trace()
            x1_p = self.layer1_p(x_p) #1 #4x32x16x16
            # print(x1.shape)
            x2_p = self.layer2_p(x1_p) #2 #4x64x8x8
            # print(x2.shape)
            x3_p = self.layer3_p(x2_p) #4 #4x128x4x4
            # # print(x3.shape)
            x4_p = self.layer4_p(x3_p) #1  #4x256x2x2

            x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2,2), mode ='bilinear'))#4x256x2x2
            x_p = torch.add(x_p, x4_p)
            x_p = F.relu(F.interpolate(self.decoder2_p(x_p) , scale_factor=(2,2), mode ='bilinear'))#4x128x4x4
            x_p = torch.add(x_p, x3_p)
            x_p = F.relu(F.interpolate(self.decoder3_p(x_p) , scale_factor=(2,2), mode ='bilinear'))#4x64x8x8
            x_p = torch.add(x_p, x2_p)
            x_p = F.relu(F.interpolate(self.decoder4_p(x_p) , scale_factor=(2,2), mode ='bilinear'))#4x32x16x16
            x_p = torch.add(x_p, x1_p)
            x_p = F.relu(F.interpolate(self.decoder5_p(x_p) , scale_factor=(2,2), mode ='bilinear'))#4x16x32x32

            x_loc[:,:,32*i:32*(i+1),32*j:32*(j+1)] = x_p
jeya-maria-jose commented 2 years ago

Hi,

Yes, we do not consider attention between patches. MedT is not a sequence-to-sequence architecture like ViT. We show that even without inter-patch attention, MedT achieves good performance without using any pretrained weights.
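
To make the answer concrete, here is a minimal, self-contained sketch of what the quoted loop does: the input is split into non-overlapping 32x32 patches, each patch is passed through the local branch on its own, and the outputs are stitched back together, so any attention computed inside the branch only ever sees pixels from its own patch. The names `LocalBranchSketch` and `patch_branch` are placeholders for illustration, not identifiers from the repository, and the stand-in conv is not the actual MedT local branch.

    import torch
    import torch.nn as nn

    class LocalBranchSketch(nn.Module):
        """Illustrative only: run a per-patch network over non-overlapping patches."""
        def __init__(self, patch_branch: nn.Module, patch_size: int = 32):
            super().__init__()
            self.patch_branch = patch_branch  # stand-in for the conv + axial-attention stack above
            self.patch_size = patch_size

        def forward(self, xin: torch.Tensor) -> torch.Tensor:
            B, C, H, W = xin.shape
            p = self.patch_size
            x_loc = None
            for i in range(H // p):
                for j in range(W // p):
                    x_p = xin[:, :, p*i:p*(i+1), p*j:p*(j+1)]
                    y_p = self.patch_branch(x_p)  # any attention inside sees only this patch
                    if x_loc is None:
                        # allocate the output once the branch's channel count is known
                        x_loc = xin.new_zeros(B, y_p.shape[1], H, W)
                    x_loc[:, :, p*i:p*(i+1), p*j:p*(j+1)] = y_p
            return x_loc

    # usage with a stand-in patch network that keeps the 32x32 spatial size
    local = LocalBranchSketch(nn.Conv2d(1, 16, kernel_size=3, padding=1))
    out = local(torch.randn(4, 1, 128, 128))  # -> torch.Size([4, 16, 128, 128])

Because each `x_p` is cropped before the branch runs, no information flows between patches in this path; in MedT the full-image context is instead handled by the separate global branch described in the paper.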