AghdamAmir / 3D-UNet

A PyTorch implementation of 3D UNet for 3D MRI segmentation.

UpConv3DBlock as well as activation function #1

Closed: jizhang02 closed this issue 1 year ago

jizhang02 commented 1 year ago

Hello Amir, I see your code is pretty good; it looks clean and brief.

import torch
from torch import nn

class UpConv3DBlock(nn.Module):

    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
        # Transposed conv doubles each spatial dimension and keeps the channel count
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2, 2, 2), stride=2)
        self.relu = nn.ReLU()
        # One BatchNorm per conv, so the two layers do not share parameters or running statistics
        self.bn1 = nn.BatchNorm3d(num_features=in_channels//2)
        self.bn2 = nn.BatchNorm3d(num_features=in_channels//2)
        # conv1 sees the upsampled features concatenated with the skip connection
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            # 1x1x1 conv maps the features to one raw score (logit) per class
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))

    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual is not None:
            # Concatenate the skip connection along the channel dimension
            out = torch.cat((out, residual), dim=1)
        out = self.relu(self.bn1(self.conv1(out)))
        out = self.relu(self.bn2(self.conv2(out)))
        if self.last_layer:
            out = self.conv3(out)
        return out

In this part, I don't quite understand the statement assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'. Why do you emphasize last_layer in this function?

Also, I didn't see an activation function in the last layer.

So I added one (the lines marked "here" below). Maybe that makes it more complete?

# Conv3DBlock comes from the same repository (not shown here)
class UNet3D(nn.Module):

    def __init__(self, in_channels, num_classes, level_channels=(32, 64, 128), bottleneck_channel=256) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck=True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)
        # here: added activation; sigmoid maps each logit to [0, 1]
        # (sigmoid suits binary/multi-label output, softmax over dim=1 suits mutually exclusive classes)
        self.final_activation = nn.Sigmoid()

    def forward(self, input):
        # Analysis (contracting) path
        out, residual_level1 = self.a_block1(input)
        out, residual_level2 = self.a_block2(out)
        out, residual_level3 = self.a_block3(out)
        out, _ = self.bottleNeck(out)

        # Synthesis (expanding) path, fed with the skip connections
        out = self.s_block3(out, residual_level3)
        out = self.s_block2(out, residual_level2)
        out = self.s_block1(out, residual_level1)
        out = self.final_activation(out)  # and here: turn the logits into probabilities
        return out

Thank you! Jing

AghdamAmir commented 1 year ago


Hi Jing,

Thank you for taking the time to review my code. I'm glad that you found it clean and brief.

Regarding your question about the assert statement in the UpConv3DBlock class, the purpose of that statement is to ensure that the arguments passed to the constructor are valid. The last_layer argument is used to indicate whether the current block is the last block in the network. If last_layer is True, then the num_classes argument must be provided to specify the number of classes in the output. If last_layer is False, then num_classes should not be provided. The assert statement checks whether this condition is met and raises an error if it is not. This is done to prevent the user from making a mistake in specifying the arguments.
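The assert accepts exactly two combinations: last_layer=False with num_classes=None, and last_layer=True with num_classes set. A minimal sketch that enumerates all four cases; valid_upconv_args and valid_upconv_args_compact are hypothetical helpers written for illustration (they are not part of the repository), and the second one shows the same check written as a single comparison:

```python
def valid_upconv_args(last_layer, num_classes):
    """Hypothetical helper mirroring the assert in UpConv3DBlock.__init__."""
    return (not last_layer and num_classes is None) or (last_layer and num_classes is not None)


def valid_upconv_args_compact(last_layer, num_classes):
    # Equivalent form: a class count is given exactly when this is the last layer.
    return last_layer == (num_classes is not None)


cases = [(False, None), (False, 3), (True, None), (True, 3)]
for last_layer, num_classes in cases:
    ok = valid_upconv_args(last_layer, num_classes)
    # Both formulations agree on every combination.
    assert ok == valid_upconv_args_compact(last_layer, num_classes)
    print(last_layer, num_classes, "ok" if ok else "AssertionError")
```

One design note: Python skips assert statements when run with the -O flag, so raising a ValueError is the more robust way to validate constructor arguments if the check must always run.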

As for your addition of the final_activation attribute, you're absolutely right that the network has no final activation. However, a number of widely used segmentation loss functions, such as CrossEntropyLoss, expect raw logits as input and apply the softmax internally, so adding an activation before them would apply it twice. This is why I haven't added any activation functions to my implementation.
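A minimal plain-Python sketch of why feeding probabilities into such a loss is a problem. It assumes the per-element form of nn.CrossEntropyLoss (log-softmax over the class scores, then the negative log-probability of the target class); the helper functions and the logit values are made up for illustration and are not from the repository:

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def softmax(logits):
    return [math.exp(v) for v in log_softmax(logits)]


def cross_entropy(scores, target):
    # Cross-entropy for one voxel, treating `scores` as raw logits.
    return -log_softmax(scores)[target]


voxel_logits = [2.0, 0.5, -1.0]  # illustrative scores for 3 classes at one voxel
target = 0

loss_on_logits = cross_entropy(voxel_logits, target)
# Passing probabilities applies softmax twice: inputs in [0, 1] get flattened,
# so the loss can no longer approach zero even for a confident, correct prediction.
loss_on_probs = cross_entropy(softmax(voxel_logits), target)
print(f"{loss_on_logits:.4f} vs {loss_on_probs:.4f}")
```

The same reasoning applies to the sigmoid case: nn.BCEWithLogitsLoss takes logits and applies the sigmoid internally.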

In case you need to use the output of the network for something other than passing it through a loss function, you can add an activation function at the end, just as you have done. To make it even cleaner, you can add another argument to the constructor of the class that specifies whether the output should be logits or probabilities.
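A minimal sketch of that suggestion, assuming PyTorch. So that it runs on its own, it wraps an arbitrary backbone instead of editing UNet3D; SegmentationHead, return_probs and multiclass are hypothetical names, and the same flag could equally be added to UNet3D.__init__:

```python
import torch
from torch import nn


class SegmentationHead(nn.Module):
    """Hypothetical wrapper (not in the repository): returns logits by default,
    or probabilities when return_probs=True."""

    def __init__(self, backbone: nn.Module, return_probs: bool = False, multiclass: bool = True) -> None:
        super().__init__()
        self.backbone = backbone
        self.return_probs = return_probs
        # Softmax over the channel (class) dimension for mutually exclusive classes,
        # sigmoid for binary or multi-label outputs.
        self.activation = nn.Softmax(dim=1) if multiclass else nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.backbone(x)
        return self.activation(logits) if self.return_probs else logits


# Stand-in backbone: a 1x1x1 conv mapping 1 input channel to 3 class scores.
backbone = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=1)
volume = torch.randn(2, 1, 8, 8, 8)  # (batch, channels, depth, height, width)

logits = SegmentationHead(backbone)(volume)                    # for nn.CrossEntropyLoss
probs = SegmentationHead(backbone, return_probs=True)(volume)  # for inference or visualization
print(logits.shape)
```

Keeping logits as the default means the training code stays unchanged, while inference code opts in to probabilities explicitly.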

I hope this clears up any confusion. Let me know if you have any further questions or concerns. Thank you again for your feedback.

Best regards, Amir

jizhang02 commented 1 year ago

Hi Amir,

Yes, I understand now; it's much clearer.
Thank you very much!