black0017 / MedicalZooPytorch

A pytorch-based deep learning framework for multi-modal 2D/3D medical image segmentation
MIT License
1.72k stars 299 forks

The OutputTransition of the VNet model has one more Conv3d layer than the paper specifies #30

Open XmySz opened 1 year ago

XmySz commented 1 year ago

According to the structure diagram of the paper, the final OutputTransition should look like the following:

class OutputTransition(nn.Module):
    def __init__(self, in_channels, classes, elu):
        super(OutputTransition, self).__init__()
        self.classes = classes
        # self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)    # modified: replaced by the 1x1x1 convolution below
        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=1)
        self.bn1 = torch.nn.BatchNorm3d(classes)

        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
        self.relu1 = ELUCons(elu, classes)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        # out = self.conv2(out)  # modified: second convolution removed
        return out

We simply use a single 1×1×1 convolutional layer to make the number of output channels equal to the number of classes.
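To illustrate the point, here is a minimal sketch comparing a 1×1×1 convolution with the 5×5×5 one from the commented-out line. The channel count, class count, and input size are assumed values chosen only for the example, not taken from the repository. Both layers keep the spatial dimensions, but the pointwise layer only mixes channels and needs far fewer parameters.

```python
import torch
import torch.nn as nn

# Assumed values, for illustration only (not taken from the repository).
in_channels, classes = 32, 2

# Pointwise (1x1x1) convolution: mixes channels, leaves D/H/W untouched.
pointwise = nn.Conv3d(in_channels, classes, kernel_size=1)
# 5x5x5 convolution with padding=2, as in the commented-out line.
wide = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)


def count_params(module):
    """Total number of learnable parameters (weights and biases)."""
    return sum(p.numel() for p in module.parameters())


x = torch.randn(1, in_channels, 8, 16, 16)  # (N, C, D, H, W)
out_pointwise = pointwise(x)
out_wide = wide(x)

pointwise_params = count_params(pointwise)  # 32 * 2 * 1 + 2
wide_params = count_params(wide)            # 32 * 2 * 125 + 2

print(out_pointwise.shape, out_wide.shape)
print(pointwise_params, wide_params)
```

With these assumed sizes, both outputs have `classes` channels and the same spatial shape as the input, so the 1×1×1 layer alone is enough to produce per-voxel class scores.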