TylerYep / torchinfo


Wrong recursive tag and duplicated/unordered entries #157

Closed: CappellatoAlessio closed this issue 2 years ago

CappellatoAlessio commented 2 years ago

**Describe the bug**
`summary` wrongly reports modules as recursive, and several entries are duplicated and not listed in any logical order.
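For comparison, the "(recursive)" tag is presumably meant for modules that really are applied more than once in a forward pass, e.g. under weight sharing. A minimal sketch of such a model (the class name is made up for illustration), where the tag would be legitimate:

    import torch
    from torch import nn
    from torchinfo import summary

    class WeightShared(nn.Module):
        """Hypothetical example: the SAME Conv2d object runs twice per
        forward pass, so its parameters exist only once and repeated
        occurrences in the summary could fairly be tagged "(recursive)"."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(8, 8, 3, padding="same")

        def forward(self, x):
            # Deliberate reuse of one module object (weight sharing).
            return self.conv(self.conv(x))

    summary(WeightShared(), input_data=[torch.rand(2, 8, 32, 32)])

In the model below, by contrast, no module is reused, so no entry should carry that tag.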

**To Reproduce**
Steps to reproduce the behavior:

  1. conda create -n recursivetest -c defaults -c pytorch -c conda-forge python pytorch torchvision torchaudio [cpuonly] torchinfo
  2. conda activate recursivetest
  3. Save the following as `recursivetest.py` (minimal example):
    
    import torch
    from torch import nn
    from torchinfo import summary

    class RecursiveTest(nn.Module):
        def __init__(self):
            super().__init__()
            self.out_conv0 = nn.Conv2d(3, 8, 5, padding='same')
            self.out_bn0 = nn.BatchNorm2d(8)

            self.block0 = nn.ModuleDict()
            for i in range(1, 4):
                self.block0.add_module(f"in_conv{i}", nn.Conv2d(8, 8, 3, padding="same", dilation=2 ** i))
                self.block0.add_module(f"in_bn{i}", nn.BatchNorm2d(8))

            self.block1 = nn.ModuleDict()
            for i in range(4, 7):
                self.block1.add_module(f"in_conv{i}", nn.Conv2d(8, 8, 3, padding="same", dilation=2 ** (7 - i)))
                self.block1.add_module(f"in_bn{i}", nn.BatchNorm2d(8))

            self.out_conv7 = nn.Conv2d(8, 1, 1, padding='same')
            self.out_bn7 = nn.BatchNorm2d(1)

        def forward(self, x):
            x = self.out_conv0(x)
            x = torch.relu(self.out_bn0(x))

            for i in range(1, 4):
                x = self.block0[f"in_conv{i}"](x)
                x = torch.relu(self.block0[f"in_bn{i}"](x))

            for i in range(4, 7):
                x = self.block1[f"in_conv{i}"](x)
                x = torch.relu(self.block1[f"in_bn{i}"](x))

            x = self.out_conv7(x)
            x = torch.relu(self.out_bn7(x))
            return x

    if __name__ == '__main__':
        batch_size = 2
        data_shape = (3, 128, 128)
        random_data = torch.rand((batch_size, *data_shape))
        my_nn = RecursiveTest()
        print(my_nn)
        summary(my_nn, input_data=[random_data], row_settings=('depth', 'var_names'))

4. Run `python recursivetest.py`
5. See output:

    RecursiveTest(
      (out_conv0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1), padding=same)
      (out_bn0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (block0): ModuleDict(
        (in_conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(2, 2))
        (in_bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (in_conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(4, 4))
        (in_bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (in_conv3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(8, 8))
        (in_bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (block1): ModuleDict(
        (in_conv4): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(8, 8))
        (in_bn4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (in_conv5): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(4, 4))
        (in_bn5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (in_conv6): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(2, 2))
        (in_bn6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (out_conv7): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (out_bn7): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )

    ==========================================================================================
    Layer (type (var_name):depth-idx)        Output Shape              Param #
    ==========================================================================================
    RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
    ├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
    ├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
    ├─ModuleDict (block0): 1-3               --                        1,800
    │    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
    │    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
    │    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
    ├─ModuleDict (block1): 1-4               --                        1,800
    ├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
    ├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
    ├─ModuleDict (block0): 1-3               --                        1,800
    │    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
    │    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
    │    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
    │    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
    ├─ModuleDict (block1): 1-4               --                        1,800
    │    └─Conv2d (in_conv4): 2-13           [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn4): 2-14        [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv5): 2-15           [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn5): 2-16        [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv6): 2-17           [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn6): 2-18        [2, 8, 128, 128]          16
    ├─Conv2d (out_conv7): 1-7                [2, 1, 128, 128]          9
    ├─BatchNorm2d (out_bn7): 1-8             [2, 1, 128, 128]          2
    ==========================================================================================
    Total params: 4,235
    Trainable params: 4,235
    Non-trainable params: 0
    Total mult-adds (M): 212.37
    ==========================================================================================
    Input size (MB): 0.39
    Forward/backward pass size (MB): 13.11
    Params size (MB): 0.01
    Estimated Total Size (MB): 13.51
    ==========================================================================================


**Expected behavior**
Each `nn.Module` (all of them are used exactly once) should appear only once, ideally in traversal order, with `depth-idx`, `Output Shape`, and `Param #` reported in a consistent, logical way:

    ==========================================================================================
    Layer (type (var_name):depth-idx)        Output Shape              Param #
    ==========================================================================================
    RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
    ├─Conv2d (out_conv0): 1-1                [2, 8, 128, 128]          608
    ├─BatchNorm2d (out_bn0): 1-2             [2, 8, 128, 128]          16
    ├─ModuleDict (block0): 1-3               --                        1,800
    │    └─Conv2d (in_conv1): 2-1            [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn1): 2-2         [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv2): 2-3            [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn2): 2-4         [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv3): 2-5            [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn3): 2-6         [2, 8, 128, 128]          16
    ├─ModuleDict (block1): 1-4               --                        1,800
    │    └─Conv2d (in_conv4): 2-7            [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn4): 2-8         [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv5): 2-9            [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn5): 2-10        [2, 8, 128, 128]          16
    │    └─Conv2d (in_conv6): 2-11           [2, 8, 128, 128]          584
    │    └─BatchNorm2d (in_bn6): 2-12        [2, 8, 128, 128]          16
    ├─Conv2d (out_conv7): 1-5                [2, 1, 128, 128]          9
    ├─BatchNorm2d (out_bn7): 1-6             [2, 1, 128, 128]          2
    ==========================================================================================
    Total params: 4,235
    Trainable params: 4,235
    Non-trainable params: 0
    Total mult-adds (M): 212.37
    ==========================================================================================
    Input size (MB): 0.39
    Forward/backward pass size (MB): 13.11
    Params size (MB): 0.01
    Estimated Total Size (MB): 13.51
    ==========================================================================================



**Desktop:**
 - OS: Windows 10
 - Version: 20H2

**Additional context**
By adding a `print(id(self))` inside the `forward(self, input)` of `nn.Conv2d` (and of `nn._BatchNorm`) at `~\Anaconda3\envs\recursivetest\Lib\site-packages\torch\nn\modules\conv.py` (`batchnorm.py`), I confirmed that every module is used exactly once (all ids are unique).
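The same check can also be made without patching the installed PyTorch sources, by counting calls with forward hooks. A minimal sketch (the helper name is made up; it assumes the `RecursiveTest` model from above):

    import torch
    from torch import nn
    from collections import Counter

    def count_forward_calls(model: nn.Module, example_input: torch.Tensor) -> Counter:
        """Count how many times each submodule's forward() fires in one pass."""
        calls: Counter = Counter()
        # Register a hook on every named submodule; the default argument
        # binds the current name into each lambda.
        handles = [
            m.register_forward_hook(lambda mod, inp, out, name=name: calls.update([name]))
            for name, m in model.named_modules()
            if name  # skip the root module itself
        ]
        try:
            model(example_input)
        finally:
            for h in handles:
                h.remove()
        return calls

    # Every module that fires should fire exactly once:
    # counts = count_forward_calls(RecursiveTest(), torch.rand(2, 3, 128, 128))
    # assert all(v == 1 for v in counts.values()), counts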
CappellatoAlessio commented 2 years ago

Things get even worse with `verbose=2`:

  1. Run `summary(my_nn, input_data=[random_data], row_settings=('depth', 'var_names'), verbose=2)`
  2. See output:
    ==========================================================================================
    Layer (type (var_name):depth-idx)        Output Shape              Param #
    ==========================================================================================
    RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
    ├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
    │    └─weight                                                      ├─600
    │    └─bias                                                        └─8
    ├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
    │    └─weight                                                      ├─8
    │    └─bias                                                        └─8
    ├─ModuleDict (block0): 1-3               --                        1,800
    │    └─in_conv1.weight                                             ├─576
    │    └─in_conv1.bias                                               ├─8
    │    └─in_bn1.weight                                               ├─8
    │    └─in_bn1.bias                                                 ├─8
    │    └─in_conv2.weight                                             ├─576
    │    └─in_conv2.bias                                               ├─8
    │    └─in_bn2.weight                                               ├─8
    │    └─in_bn2.bias                                                 ├─8
    │    └─in_conv3.weight                                             ├─576
    │    └─in_conv3.bias                                               ├─8
    │    └─in_bn3.weight                                               ├─8
    │    └─in_bn3.bias                                                 └─8
    │    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    ├─ModuleDict (block1): 1-4               --                        1,800
    │    └─in_conv4.weight                                             ├─576
    │    └─in_conv4.bias                                               ├─8
    │    └─in_bn4.weight                                               ├─8
    │    └─in_bn4.bias                                                 ├─8
    │    └─in_conv5.weight                                             ├─576
    │    └─in_conv5.bias                                               ├─8
    │    └─in_bn5.weight                                               ├─8
    │    └─in_bn5.bias                                                 ├─8
    │    └─in_conv6.weight                                             ├─576
    │    └─in_conv6.bias                                               ├─8
    │    └─in_bn6.weight                                               ├─8
    │    └─in_bn6.bias                                                 └─8
    ├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
    │    └─weight                                                      ├─600
    │    └─bias                                                        └─8
    ├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
    │    └─weight                                                      ├─8
    │    └─bias                                                        └─8
    ├─ModuleDict (block0): 1-3               --                        1,800
    │    └─in_conv1.weight                                             ├─576
    │    └─in_conv1.bias                                               ├─8
    │    └─in_bn1.weight                                               ├─8
    │    └─in_bn1.bias                                                 ├─8
    │    └─in_conv2.weight                                             ├─576
    │    └─in_conv2.bias                                               ├─8
    │    └─in_bn2.weight                                               ├─8
    │    └─in_bn2.bias                                                 ├─8
    │    └─in_conv3.weight                                             ├─576
    │    └─in_conv3.bias                                               ├─8
    │    └─in_bn3.weight                                               ├─8
    │    └─in_bn3.bias                                                 └─8
    │    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    ├─ModuleDict (block1): 1-4               --                        1,800
    │    └─in_conv4.weight                                             ├─576
    │    └─in_conv4.bias                                               ├─8
    │    └─in_bn4.weight                                               ├─8
    │    └─in_bn4.bias                                                 ├─8
    │    └─in_conv5.weight                                             ├─576
    │    └─in_conv5.bias                                               ├─8
    │    └─in_bn5.weight                                               ├─8
    │    └─in_bn5.bias                                                 ├─8
    │    └─in_conv6.weight                                             ├─576
    │    └─in_conv6.bias                                               ├─8
    │    └─in_bn6.weight                                               ├─8
    │    └─in_bn6.bias                                                 └─8
    │    └─Conv2d (in_conv4): 2-13           [2, 8, 128, 128]          584
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn4): 2-14        [2, 8, 128, 128]          16
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv5): 2-15           [2, 8, 128, 128]          584
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn5): 2-16        [2, 8, 128, 128]          16
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    │    └─Conv2d (in_conv6): 2-17           [2, 8, 128, 128]          584
    │    │    └─weight                                                 ├─576
    │    │    └─bias                                                   └─8
    │    └─BatchNorm2d (in_bn6): 2-18        [2, 8, 128, 128]          16
    │    │    └─weight                                                 ├─8
    │    │    └─bias                                                   └─8
    ├─Conv2d (out_conv7): 1-7                [2, 1, 128, 128]          9
    │    └─weight                                                      ├─8
    │    └─bias                                                        └─1
    ├─BatchNorm2d (out_bn7): 1-8             [2, 1, 128, 128]          2
    │    └─weight                                                      ├─1
    │    └─bias                                                        └─1
    ==========================================================================================
    Total params: 4,235
    Trainable params: 4,235
    Non-trainable params: 0
    Total mult-adds (M): 212.37
    ==========================================================================================
    Input size (MB): 0.39
    Forward/backward pass size (MB): 13.11
    Params size (MB): 0.01
    Estimated Total Size (MB): 13.51
    ==========================================================================================
TylerYep commented 2 years ago

Thanks for reporting this issue! PRs to fix this are much appreciated.

mert-kurttutan commented 2 years ago

Hi,

I have a potential solution that fixes this issue and passes the existing test cases. While working on it, I found some other cases that are problematic both for the current implementation and for my solution. But to be really certain about my approach, I need to ask a few general questions, e.g. about the `add_missing_layers()` function in `torchinfo.py`.

How should I proceed? Should I continue the general discussion here, or open a separate issue for it?

TylerYep commented 2 years ago

Feel free to open the PR (even if it is a draft), and we can discuss the issues on the PR itself.