MIC-DKFZ / HD-BET

MRI brain extraction tool
Apache License 2.0

Regarding application of InstanceNorm3d and F.leaky_relu twice on the input #3

Closed · Geeks-Sid closed this issue 5 years ago

Geeks-Sid commented 5 years ago

Hi Fabian,

I noticed something odd in the code. It seems to me that you are applying instance normalization followed by a LeakyReLU twice to a given input. This can be seen in the following code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsamplingModule(nn.Module):

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)

    def forward(self, x):
        # normalize + activate, then return both the activated features (the skip)
        # and the strided-convolution output
        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        b = self.downsample(x)
        return x, b


class LocalizationModule(nn.Module):

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        # conv -> instnorm -> lrelu, applied twice
        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        return x
```

```python
# excerpt from the network's forward pass:
skip4, x = self.down4(x)
x = torch.cat((skip4, x), dim=1)
x = self.loc1(x)
x = self.up2(x)
```

If we look carefully, the DownsamplingModule returns skip4 after applying instance normalization and LeakyReLU to its input, and it also returns the result of the downsampling convolution.

Later we concatenate skip4 with x and apply instance normalization and LeakyReLU to the result. But we know skip4 has already had instance normalization and LeakyReLU applied to it.

I don't understand how it makes sense to apply instance normalization and LeakyReLU twice to a given input. Can you please explain why you do this?
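To make the data flow along this path concrete, here is a minimal sketch reusing the class definitions quoted above. The channel counts, the input shape, and the `torch.randn_like` stand-in for the decoder branch are placeholders, not HD-BET's real configuration:

```python
import torch

# hypothetical channel counts and input size, just to make the flow concrete
down4 = DownsamplingModule(in_channels=8, out_channels=16)
loc1 = LocalizationModule(in_channels=16, out_channels=8)

inp = torch.randn(1, 8, 16, 16, 16)
skip4, x = down4(inp)        # skip4 = lrelu(instnorm(inp)); x = downsample(skip4)
# ... in the real network x continues through deeper blocks and is upsampled
#     again; torch.randn_like(skip4) stands in for that branch here
x = torch.cat((skip4, torch.randn_like(skip4)), dim=1)
x = loc1(x)                  # loc1 applies instnorm + lrelu again (after conv1)
print(skip4.shape, x.shape)  # torch.Size([1, 8, 16, 16, 16]) torch.Size([1, 8, 16, 16, 16])
```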

FabianIsensee commented 5 years ago

Hi,

```python
skip4, x = self.down4(x)
x = self.context5(x)

x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
x = self.up1(x)

x = torch.cat((skip4, x), dim=1)
x = self.loc1(x)
```

skip4 has instnorm and LeakyReLU applied to it, yes. When it is concatenated with x, x is what the UpsamplingModule spits out. That one also has instnorm and LeakyReLU applied to it, so that's fine, too. The concatenated feature maps are then forwarded into a LocalizationModule:

```python
def forward(self, x):
    x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
    x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
    return x
```

The first operation applied to x is self.conv1(x), which is what I intended. There is no double application of either instnorm or LReLU. Unless I made a mistake, there should be no bug there ;)

Hope this helps,
Best,
Fabian
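For anyone who wants to convince themselves of this, here is a minimal check reusing the LocalizationModule definition quoted above (the channel counts and input shape are placeholders). It records the order in which the module's submodules fire during one forward pass; the functional F.leaky_relu calls do not appear because they are not submodules:

```python
import torch

loc1 = LocalizationModule(in_channels=16, out_channels=8)

call_order = []
for name, sub in loc1.named_modules():
    if name:  # skip the root module itself
        sub.register_forward_hook(lambda mod, inp, out, n=name: call_order.append(n))

x = torch.randn(1, 16, 8, 8, 8)   # stand-in for torch.cat((skip4, x), dim=1)
loc1(x)
print(call_order)                 # ['conv1', 'bn_1', 'conv2', 'bn_2']
```

The concatenated tensor goes through conv1 before it reaches the next InstanceNorm3d, so the already-normalized skip4 features are never normalized twice back to back.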