TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

MACS calculation error when the model structure is nested #60

Closed james20141606 closed 3 years ago

james20141606 commented 3 years ago

Hi, thanks for the tool you provide, it's very useful. However, I found a bug when calculating each layer's Mult-Adds for a nested model. I got something like this:

[screenshot: torchinfo summary output with missing Mult-Adds for the nested layers]

For most of the layers, such as TMVANet( (encoder): TMVANet_Encoder( (rd_encoding_branch): EncodingBranch( (double_3dconv_block1): Double3DConvBlock ..., I could not get the Mult-Adds information correctly. I assume this is because the block is wrapped several times and is not handled correctly? Could you please tell me how to solve this problem?

The initial part of my model looks like this:

[screenshot: the initial part of the model architecture]
james20141606 commented 3 years ago

It seems that the total Mult-Adds is much bigger than the sum of the few displayed layers' Mult-Adds. Maybe the program does calculate each layer's Mult-Adds but does not display them correctly?

TylerYep commented 3 years ago

From torchinfo's output, it looks like the mult-adds in the rightmost column are being summed correctly to form the total at the bottom. Is there a specific layer that does not look correct? The math looks approximately correct to me.

james20141606 commented 3 years ago

Hi, thanks for the response. I don't think the sum of the rightmost column equals the total at the bottom: the sum is less than 40 G, but the total is 290 G. I believe the total is correct, since I benchmarked some popular models and the MACs almost match the results reported elsewhere. But I am afraid there is a display issue for the individual layers. The summary should display each layer's info; I can get that if I summarize only the first sub-module on its own, but doing that for every layer would be very inconvenient. Could you look into the display issue again? Thanks a lot!

[screenshot: torchinfo summary output]
james20141606 commented 3 years ago
[screenshot: summary output showing the total mult-adds unit issue]

By the way, there is another small problem: if the MACs are less than 1 G, the total mult-adds does not automatically switch to the M unit but is still reported in G, which is not correct.
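
For illustration, a minimal sketch of the auto-scaling I would expect (this is not torchinfo's actual implementation, just the idea):

def format_mult_adds(macs: int) -> str:
    """Pick the largest unit that keeps the value at or above 1 (illustrative only)."""
    for divisor, unit in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "k")):
        if macs >= divisor:
            return f"Total mult-adds ({unit}): {macs / divisor:.2f}"
    return f"Total mult-adds: {macs}"

print(format_mult_adds(950_000_000))  # -> Total mult-adds (M): 950.00, not (G): 0.95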

TylerYep commented 3 years ago

Can you post your full code as text? I believe this issue should be fixed in torchinfo>=1.5.0, which I just released.

james20141606 commented 3 years ago

Sure. The code to reproduce:

import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConvBlock(nn.Module):
    """ (2D conv => BN => LeakyReLU) * 2 """

    def __init__(self, in_ch, out_ch, k_size, pad, dil):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k_size, padding=pad, dilation=dil),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=k_size, padding=pad, dilation=dil),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        x = self.block(x)
        return x

class Double3DConvBlock(nn.Module):
    """ (3D conv => BN => LeakyReLU) * 2 """

    def __init__(self, in_ch, out_ch, k_size, pad, dil):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=k_size, padding=pad, dilation=dil),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=k_size, padding=pad, dilation=dil),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        x = self.block(x)
        return x

class ConvBlock(nn.Module):
    """ (2D conv => BN => LeakyReLU) """

    def __init__(self, in_ch, out_ch, k_size, pad, dil):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k_size, padding=pad, dilation=dil),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        x = self.block(x)
        return x

class ASPPBlock(nn.Module):
    """Atrous Spatial Pyramid Pooling
    Parallel conv blocks with different dilation rate
    """

    def __init__(self, in_ch, out_ch=256):
        super().__init__()
        self.global_avg_pool = nn.AvgPool2d((64, 64))
        self.conv1_1x1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, dilation=1)
        self.single_conv_block1_1x1 = ConvBlock(in_ch, out_ch, k_size=1, pad=0, dil=1)
        self.single_conv_block1_3x3 = ConvBlock(in_ch, out_ch, k_size=3, pad=6, dil=6)
        self.single_conv_block2_3x3 = ConvBlock(in_ch, out_ch, k_size=3, pad=12, dil=12)
        self.single_conv_block3_3x3 = ConvBlock(in_ch, out_ch, k_size=3, pad=18, dil=18)

    def forward(self, x):
        x1 = F.interpolate(self.global_avg_pool(x), size=(64, 64), align_corners=False,
                           mode='bilinear')
        x1 = self.conv1_1x1(x1)
        x2 = self.single_conv_block1_1x1(x)
        x3 = self.single_conv_block1_3x3(x)
        x4 = self.single_conv_block2_3x3(x)
        x5 = self.single_conv_block3_3x3(x)
        x_cat = torch.cat((x2, x3, x4, x5, x1), 1)
        return x_cat

class EncodingBranch(nn.Module):
    """
    Encoding branch for a single radar view

    PARAMETERS
    ----------
    signal_type: str
        Type of radar view.
        Supported: 'range_doppler', 'range_angle' and 'angle_doppler'
    """

    def __init__(self, signal_type):
        super().__init__()
        self.signal_type = signal_type
        self.double_3dconv_block1 = Double3DConvBlock(in_ch=1, out_ch=128, k_size=3,
                                                      pad=(0, 1, 1), dil=1)
        self.doppler_max_pool = nn.MaxPool2d(2, stride=(2, 1))
        self.max_pool = nn.MaxPool2d(2, stride=2)
        self.double_conv_block2 = DoubleConvBlock(in_ch=128, out_ch=128, k_size=3,
                                                  pad=1, dil=1)
        self.single_conv_block1_1x1 = ConvBlock(in_ch=128, out_ch=128, k_size=1,
                                                pad=0, dil=1)

    def forward(self, x):
        x1 = self.double_3dconv_block1(x)
        x1 = torch.squeeze(x1, 2)  # remove temporal dimension

        if self.signal_type in ('range_doppler', 'angle_doppler'):
            # The Doppler dimension requires a specific processing
            x1_pad = F.pad(x1, (0, 1, 0, 0), "constant", 0)
            x1_down = self.doppler_max_pool(x1_pad)
        else:
            x1_down = self.max_pool(x1)

        x2 = self.double_conv_block2(x1_down)
        if self.signal_type in ('range_doppler', 'angle_doppler'):
            # The Doppler dimension requires a specific processing
            x2_pad = F.pad(x2, (0, 1, 0, 0), "constant", 0)
            x2_down = self.doppler_max_pool(x2_pad)
        else:
            x2_down = self.max_pool(x2)

        x3 = self.single_conv_block1_1x1(x2_down)
        # return input of ASPP block + latent features
        return x2_down, x3

class TMVANet_Encoder(nn.Module):
    """ 
    Temporal Multi-View with ASPP Network (TMVA-Net)

    PARAMETERS
    ----------
    n_classes: int
        Number of classes used for the semantic segmentation task
    n_frames: int
        Total number of frames used as a sequence
    """

    def __init__(self, n_classes, n_frames):
        super().__init__()
        self.n_classes = n_classes
        self.n_frames = n_frames

        # Backbone (encoding)
        self.rd_encoding_branch = EncodingBranch('range_doppler')
        self.ra_encoding_branch = EncodingBranch('range_angle')
        self.ad_encoding_branch = EncodingBranch('angle_doppler')

        # ASPP Blocks
        self.rd_aspp_block = ASPPBlock(in_ch=128, out_ch=128)
        self.ra_aspp_block = ASPPBlock(in_ch=128, out_ch=128)
        self.ad_aspp_block = ASPPBlock(in_ch=128, out_ch=128)
        self.rd_single_conv_block1_1x1 = ConvBlock(in_ch=640, out_ch=128, k_size=1, pad=0, dil=1)
        self.ra_single_conv_block1_1x1 = ConvBlock(in_ch=640, out_ch=128, k_size=1, pad=0, dil=1)
        self.ad_single_conv_block1_1x1 = ConvBlock(in_ch=640, out_ch=128, k_size=1, pad=0, dil=1)

    def forward(self, x_rd, x_ra, x_ad, printshape=False):
        # Backbone
        ra_features, ra_latent = self.ra_encoding_branch(x_ra)
        rd_features, rd_latent = self.rd_encoding_branch(x_rd)
        ad_features, ad_latent = self.ad_encoding_branch(x_ad)
        if printshape:
            print('ra_features, ra_latent', ra_features.shape, ra_latent.shape)

        # ASPP blocks
        x1_rd = self.rd_aspp_block(rd_features)
        x1_ra = self.ra_aspp_block(ra_features)
        x1_ad = self.ad_aspp_block(ad_features)
        x2_rd = self.rd_single_conv_block1_1x1(x1_rd)
        x2_ra = self.ra_single_conv_block1_1x1(x1_ra)
        x2_ad = self.ad_single_conv_block1_1x1(x1_ad)
        if printshape:
            print('x1_rd, x2_rd', x1_rd.shape, x2_rd.shape)

        # Features join either the RD or the RA branch
        x3 = torch.cat((rd_latent, ra_latent, ad_latent), 1)

        return x3, x2_rd, x2_ad, x2_ra

class TMVANet_Decoder(nn.Module):
    """ 
    Temporal Multi-View with ASPP Network (TMVA-Net)

    PARAMETERS
    ----------
    n_classes: int
        Number of classes used for the semantic segmentation task
    n_frames: int
        Total number of frames used as a sequence
    """

    def __init__(self, n_classes, n_frames):
        super().__init__()
        self.n_classes = n_classes
        self.n_frames = n_frames

        # Decoding
        self.rd_single_conv_block2_1x1 = ConvBlock(in_ch=384, out_ch=128, k_size=1, pad=0, dil=1)
        self.ra_single_conv_block2_1x1 = ConvBlock(in_ch=384, out_ch=128, k_size=1, pad=0, dil=1)

        # Parallel range-Doppler (RD) and range-angle (RA) decoding branches
        self.rd_upconv1 = nn.ConvTranspose2d(384, 128, (2, 1), stride=(2, 1))
        self.ra_upconv1 = nn.ConvTranspose2d(384, 128, 2, stride=2)
        self.rd_double_conv_block1 = DoubleConvBlock(in_ch=128, out_ch=128, k_size=3,
                                                     pad=1, dil=1)
        self.ra_double_conv_block1 = DoubleConvBlock(in_ch=128, out_ch=128, k_size=3,
                                                     pad=1, dil=1)
        self.rd_upconv2 = nn.ConvTranspose2d(128, 128, (2, 1), stride=(2, 1))
        self.ra_upconv2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
        self.rd_double_conv_block2 = DoubleConvBlock(in_ch=128, out_ch=128, k_size=3,
                                                     pad=1, dil=1)
        self.ra_double_conv_block2 = DoubleConvBlock(in_ch=128, out_ch=128, k_size=3,
                                                     pad=1, dil=1)

        # Final 1x1 convs
        self.rd_final = nn.Conv2d(in_channels=128, out_channels=n_classes, kernel_size=1)
        self.ra_final = nn.Conv2d(in_channels=128, out_channels=n_classes, kernel_size=1)

    def forward(self, x3, x2_rd, x2_ad, x2_ra, printshape=False):
        # Parallel decoding branches with upconvs

        # Latent Space

        x3_rd = self.rd_single_conv_block2_1x1(x3)
        x3_ra = self.ra_single_conv_block2_1x1(x3)
        if printshape:
            print('x3, x3_rd', x3.shape, x3_rd.shape)

        # Latent Space + ASPP features
        x4_rd = torch.cat((x2_rd, x3_rd, x2_ad), 1)
        x4_ra = torch.cat((x2_ra, x3_ra, x2_ad), 1)
        if printshape:
            print('x4_rd, cat(x2_rd, x3_rd, x2_ad)', x4_rd.shape, x2_rd.shape, x3_rd.shape, x2_ad.shape)

        x5_rd = self.rd_upconv1(x4_rd)
        x5_ra = self.ra_upconv1(x4_ra)
        x6_rd = self.rd_double_conv_block1(x5_rd)
        x6_ra = self.ra_double_conv_block1(x5_ra)
        if printshape:
            print('x5_rd, x6_rd', x5_rd.shape, x6_rd.shape)

        x7_rd = self.rd_upconv2(x6_rd)
        x7_ra = self.ra_upconv2(x6_ra)
        x8_rd = self.rd_double_conv_block2(x7_rd)
        x8_ra = self.ra_double_conv_block2(x7_ra)
        if printshape:
            print('x7_rd, x8_rd', x7_rd.shape, x8_rd.shape)

        # Final 1x1 convolutions
        x9_rd = self.rd_final(x8_rd)
        x9_ra = self.ra_final(x8_ra)

        return x9_rd, x9_ra

class TMVANet(nn.Module):
    """ 
    Temporal Multi-View with ASPP Network (TMVA-Net)

    PARAMETERS
    ----------
    n_classes: int
        Number of classes used for the semantic segmentation task
    n_frames: int
        Total number of frames used as a sequence
    """

    def __init__(self, n_classes, n_frames):
        super().__init__()
        self.n_classes = n_classes
        self.n_frames = n_frames

        self.encoder = TMVANet_Encoder(n_classes, n_frames)
        self.decoder = TMVANet_Decoder(n_classes, n_frames)

    def forward(self, x_rd, x_ra, x_ad, printshape=False):
        x3, x2_rd, x2_ad, x2_ra = self.encoder(x_rd, x_ra, x_ad)
        x9_rd, x9_ra = self.decoder(x3, x2_rd, x2_ad, x2_ra)
        return x9_rd, x9_ra

from torchinfo import summary
net = TMVANet(n_classes=4, n_frames=5)
summary(net,
        input_data=[torch.randn(1, 1, 5, 256, 64),
                    torch.randn(1, 1, 5, 256, 256),
                    torch.randn(1, 1, 5, 256, 64)],
        verbose=0,
        col_names=["output_size", "num_params", "mult_adds"])  # "input_size" column omitted

I tried updating torchinfo, but the issue still seems to be there.
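
A quick way to double-check which version is actually being imported, using only the standard library (Python 3.8+):

from importlib.metadata import version
print(version("torchinfo"))  # the fix above requires >= 1.5.0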

james20141606 commented 3 years ago

Yes, that is certainly correct: when I break the model down to a single sequential module, it works well. The problem is that, in the very first figure, some layers have no MACs information and the sum is not the total MACs. So I suspect something goes wrong when the layers are wrapped several times. For example, for layers like

TMVANet(
  (encoder): TMVANet_Encoder(
    (rd_encoding_branch): EncodingBranch(
      (double_3dconv_block1): Double3DConvBlock

the Double3DConvBlock layer does not have MACs information, so the sum of the displayed layers doesn't match the total MACs. I believe the total MACs are calculated correctly, so there might be a problem with the display function.
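
For what it's worth, the display issue can probably be reproduced with a much smaller, hypothetical model that uses the same nesting pattern (an nn.Sequential wrapped twice in plain nn.Module containers, mirroring EncodingBranch and Double3DConvBlock above):

import torch.nn as nn
from torchinfo import summary

class Inner(nn.Module):
    """Wraps an nn.Sequential, like Double3DConvBlock above."""
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())

    def forward(self, x):
        return self.block(x)

class Outer(nn.Module):
    """Wraps the wrapper, like EncodingBranch above."""
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x)

summary(Outer(), input_size=(1, 1, 32, 32),
        col_names=["output_size", "num_params", "mult_adds"])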

TylerYep commented 3 years ago

Thank you for the thorough bug report! I think I've figured out the issue and a fix in https://github.com/TylerYep/torchinfo/commit/fc6e7c13746ea3e539d90baafd852dd5665b5341.

Please let me know if this solves the issue! If so, I'll submit a new release this weekend.

james20141606 commented 3 years ago

Hi, I updated the code and found that each layer now has its MACs displayed. But I am afraid the information might not be correct. When I sum up the MACs:

import numpy as np

model_stats = summary(net,
                      input_data=[torch.randn(1, 1, 5, 256, 64),
                                  torch.randn(1, 1, 5, 256, 256),
                                  torch.randn(1, 1, 5, 256, 64)],
                      verbose=0,
                      col_names=["output_size", "num_params", "mult_adds"])  # "input_size" column omitted
summary_str = str(model_stats)
# Parse the mult-adds column out of the printed table and sum it, in G.
np.array([int(i.split('    ')[-1].replace(',', ''))
          for i in summary_str.split('\nTotal pa')[0].split('\n')
          if i.startswith('│') and not i.endswith('--')]).sum() / 1000**3

I found the summed MACs come to 9.781067776 G, but the total torchinfo displays at the end is Total mult-adds (G): 102.04.
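
Instead of parsing the printed table, the same check can also be done directly on the returned ModelStatistics object. A sketch only: the attribute names below (summary_list, macs, module, total_mult_adds) are based on torchinfo's internals and may differ between versions:

# Sum the MACs of leaf layers only and compare against the reported total.
leaf_macs = sum(layer.macs for layer in model_stats.summary_list
                if not list(layer.module.children()) and layer.macs > 0)
print(leaf_macs / 1e9, model_stats.total_mult_adds / 1e9)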

TylerYep commented 3 years ago

Currently working on a follow-up fix to this issue.

I will continue to test this functionality on parameters and mult-adds before the next release. Thank you for your help in tracking these issues down!

TylerYep commented 3 years ago

This is resolved in version 1.5.1! I generalized the display testing to work for all test models and fixed several other bugs in the process. Please let me know if you encounter any other issues, either by reopening this issue or by creating a new one.