TylerYep / torchinfo

View model summaries in PyTorch!

Resnet hierarchy comes out wrong #194

Closed · BigDevil82 closed this 1 year ago

BigDevil82 commented 1 year ago

Almost the same code as in #105, but with a different problem: the issue lies in the summary hierarchy of ResnetBlock.

This is the Generator code:

import torch.nn as nn


class GlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()        
        activation = nn.ReLU(True)        

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]

        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]

        ### upsample         
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                       norm_layer(int(ngf * mult / 2)), activation]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]        
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)        

And this is the ResnetBlock code:

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=None, use_dropout=False):
        super(ResnetBlock, self).__init__()
        activation =  nn.ReLU(True) if activation is None else activation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)  # residual (skip) connection
        return out

The output summary for ResnetBlock has the wrong structure, as shown in the screenshot below:

[screenshot: torchinfo summary output with the ResnetBlock hierarchy split across blocks]

A single ResnetBlock is supposed to contain these basic modules in a sequential container:

[screenshot: the expected module list inside a single ResnetBlock]

But when it comes to the ReLU, the summary splits it into another ResnetBlock. Is this a bug? Thanks.
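
For reference, a summary like the one in the screenshots can presumably be reproduced with a call along these lines (the channel counts and input size here are illustrative assumptions, not taken from the issue):

from torchinfo import summary

gen = GlobalGenerator(input_nc=3, output_nc=3)
# the default depth=3 already exposes the mis-nested ReLU entries
summary(gen, input_size=(1, 3, 256, 256))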

mert-kurttutan commented 1 year ago

This is essentially because of the way parameterless layers (e.g. ReLU) are used and the way torchinfo records them. It is very common to create a single activation object and reuse that same layer everywhere it is needed.

So in your case, the first time this ReLU is processed it is recorded as └─ReLU: 2-4 (right after the first batch norm), i.e. at depth 2. Since the same layer object is reused throughout the rest of the model, that recorded info keeps being used for the other ReLU occurrences as well. This is why those ReLUs also appear at depth 2, which splits the block (whose contents are at depth 3).
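
As a minimal sketch of the reuse pattern described above (the toy model is made up for illustration; the exact summary output depends on the torchinfo version):

import torch.nn as nn
from torchinfo import summary

shared_relu = nn.ReLU(True)            # one activation object...

inner = nn.Sequential(                 # ...reused again inside a nested block
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    shared_relu,
)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    shared_relu,                       # first occurrence: recorded at this depth
    inner,                             # second occurrence reuses the same record
)

summary(model, input_size=(1, 3, 32, 32))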

To resolve this in your case, you can replace every use of the activation with a deepcopy of it; see the code below.


from copy import deepcopy

import torch.nn as nn


class GlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()        
        activation = nn.ReLU(True)        

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), deepcopy(activation)]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), deepcopy(activation)]

        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=deepcopy(activation), norm_layer=norm_layer)]

        ### upsample         
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                       norm_layer(int(ngf * mult / 2)), deepcopy(activation)]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), deepcopy(nn.Tanh())]        
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)     

In this case, you will get the desired output. I think this is a feature that can be improved. The main approach would be to refresh the info of a parameterless layer every time it is used (even when it is reused). Another approach would be to make the info of parameterless layers depend on their input/output.
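
With the deepcopy version, a quick check like the one below (the input size and depth are assumptions for illustration) should show each ResnetBlock as a single entry with its padding/conv/norm/activation children nested beneath it:

from torchinfo import summary

gen = GlobalGenerator(input_nc=3, output_nc=3)
# depth=5 is deep enough to show the layers inside each conv_block
summary(gen, input_size=(1, 3, 256, 256), depth=5)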

BigDevil82 commented 1 year ago

Got it, thanks!