Closed (jizhang02 closed this issue 1 year ago)
Hello Amir, I think your code is very good; it looks clean and brief:
class UpConv3DBlock(nn.Module):
    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2, 2, 2), stride=2)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm3d(num_features=in_channels//2)
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))

    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual!=None:
            out = torch.cat((out, residual), 1)
        out = self.relu(self.bn(self.conv1(out)))
        out = self.relu(self.bn(self.conv2(out)))
        if self.last_layer:
            out = self.conv3(out)
        return out
In this part, I don't quite understand the following statement:

assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'

Why do you emphasize last_layer in this function? Also, I didn't see an activation function after the last layer, so I added one (marked "here" below). Maybe it's more complete this way, right?
class UNet3D(nn.Module):
    def __init__(self, in_channels, num_classes, level_channels=[32, 64, 128], bottleneck_channel=256) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck= True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)
        self.final_activation = nn.Sigmoid()  # here

    def forward(self, input):
        #Analysis path forward feed
        out, residual_level1 = self.a_block1(input)
        out, residual_level2 = self.a_block2(out)
        out, residual_level3 = self.a_block3(out)
        out, _ = self.bottleNeck(out)

        #Synthesis path forward feed
        out = self.s_block3(out, residual_level3)
        out = self.s_block2(out, residual_level2)
        out = self.s_block1(out, residual_level1)
        out = self.final_activation(out)  # and here
        return out
Thank you! Jing
Hi Jing,
Thank you for taking the time to review my code. I'm glad that you found it clean and brief.
Regarding your question about the assert statement in the UpConv3DBlock class: its purpose is to ensure that the arguments passed to the constructor are consistent. The last_layer argument indicates whether the current block is the last block in the network, because only the last block adds the extra 1x1x1 convolution (conv3) that maps the features to num_classes output channels. If last_layer is True, then num_classes must be provided to specify the number of classes in the output. If last_layer is False, then num_classes should not be provided. The assert statement checks that one of these two cases holds and raises an AssertionError ('Invalid arguments') otherwise. This prevents the user from accidentally passing an inconsistent combination of arguments.
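The same check can be written with explicit error messages. This is only a minimal sketch, not the code in the repository: check_block_args is a hypothetical helper name, and it raises ValueError instead of using assert so that the check also survives running Python with the -O flag.

```python
from typing import Optional


def check_block_args(last_layer: bool, num_classes: Optional[int]) -> None:
    """Hypothetical helper mirroring the assert in UpConv3DBlock.__init__."""
    if last_layer and num_classes is None:
        # The last block needs num_classes to size its final 1x1x1 convolution.
        raise ValueError("num_classes is required when last_layer=True")
    if not last_layer and num_classes is not None:
        # Intermediate blocks have no classification layer, so the value would be ignored.
        raise ValueError("num_classes must be omitted when last_layer=False")


check_block_args(last_layer=False, num_classes=None)  # valid: intermediate block
check_block_args(last_layer=True, num_classes=3)      # valid: last block
```

Either of the two remaining combinations (last_layer=True without num_classes, or num_classes given for an intermediate block) raises the error.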
As for your addition of the final_activation attribute, you're right that the network does not apply an activation function after the last layer. However, a number of widely used segmentation loss functions, such as CrossEntropyLoss, expect raw logits as input and apply the softmax internally. This is why I haven't added a final activation function to my implementation.
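As a small illustration of the point about logits, the sketch below checks that PyTorch's CrossEntropyLoss applied to raw logits gives the same value as applying log-softmax first and then the negative log-likelihood loss. The tensor shapes and the number of classes are made up for the example and do not come from the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes only: (batch, classes, depth, height, width).
logits = torch.randn(2, 3, 4, 4, 4)
# One class index per voxel: (batch, depth, height, width).
target = torch.randint(0, 3, (2, 4, 4, 4))

# CrossEntropyLoss takes the raw network output (logits) directly.
loss_from_logits = nn.CrossEntropyLoss()(logits, target)

# Equivalent manual computation: log-softmax over the class dimension, then NLL.
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(torch.allclose(loss_from_logits, loss_manual))
```

Because the softmax is already part of the loss, passing sigmoid or softmax outputs into CrossEntropyLoss would normalize the values a second time.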
In case you need to use the output of the network for something other than passing it through a loss function, you can add an activation function at the end, just as you have done. To make it even cleaner, you can add another argument to the constructor of the class that specifies whether the output should be logits or probabilities.
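One way such an argument could look is sketched below. To keep the example self-contained it wraps an arbitrary backbone instead of modifying UNet3D itself; OutputWrapper and return_probs are hypothetical names, and the single Conv3d layer is only a stand-in for the real network. Softmax over the class dimension is assumed here; for a single-channel binary output, nn.Sigmoid (as in your version) would be the analogous choice.

```python
import torch
import torch.nn as nn


class OutputWrapper(nn.Module):
    """Hypothetical wrapper that returns either logits or probabilities."""

    def __init__(self, backbone: nn.Module, return_probs: bool = False) -> None:
        super().__init__()
        self.backbone = backbone
        self.return_probs = return_probs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.backbone(x)
        if not self.return_probs:
            # Raw logits, suitable for losses such as CrossEntropyLoss.
            return logits
        # Probabilities: softmax over the channel (class) dimension.
        return torch.softmax(logits, dim=1)


# Stand-in backbone (an assumption for the example): 1 input channel, 3 classes.
backbone = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=1)
x = torch.randn(2, 1, 4, 4, 4)

logits = OutputWrapper(backbone)(x)                    # for training
probs = OutputWrapper(backbone, return_probs=True)(x)  # for inference
```

With this layout the same weights serve both purposes: logits during training and probabilities when the prediction itself is needed.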
I hope this clears up any confusion. Let me know if you have any further questions or concerns. Thank you again for your feedback.
Best regards, Amir
Hi Amir,
Yes, I understand now; it's much clearer.
Thank you very much!