TylerYep / torchinfo

View model summaries in PyTorch!

Output incorrect when using nn.ModuleList #170

Closed Ben-Drucker closed 2 years ago

Ben-Drucker commented 2 years ago

Describe the bug When using an nn.ModuleList inside a "main" module, the output from summary is incorrect if the submodules of the modules held in the ModuleList are not declared in the same order in which they are called in forward.

To Reproduce To get the incorrect output, one can run:

import torchinfo
from torch import nn

class MainModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.ml = nn.ModuleList([CNN() for _ in range(3)]) # ModuleList of three CNNs

    def forward(self, x):
        for layer in self.ml:
            x = layer(x)
        return x

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv = nn.Conv1d(1, 1, 1) # CNN submodules - note that ReLU is instantiated before Conv1d
        self.pool = nn.MaxPool1d(1)

    def forward(self, x): # The forward function of the modules in the `ModuleList`
        out = x
        out = self.conv(out)
        out = self.relu(out)
        out = self.pool(out)
        return out

def main():
    model = MainModule()
    torchinfo.summary(model, input_size=[1, 10])

if __name__ == "__main__":
    main()

The output is as follows. Note that the numbering of the CNNs is inconsistent, some items appear out of order, and several rows are duplicated.

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MainModule                               [1, 10]                   --
├─ModuleList: 1-1                        --                        --
│    └─CNN: 2-1                          [1, 10]                   --
│    │    └─ReLU: 3-3                    [1, 10]                   --
│    │    └─Conv1d: 3-2                  [1, 10]                   2
│    │    └─ReLU: 3-3                    [1, 10]                   --
│    │    └─MaxPool1d: 3-4               [1, 10]                   --
│    └─CNN: 2                            --                        --
│    │    └─ReLU: 3-7                    [1, 10]                   --
│    └─CNN: 2-2                          [1, 10]                   --
│    │    └─Conv1d: 3-6                  [1, 10]                   2
│    │    └─ReLU: 3-7                    [1, 10]                   --
│    └─CNN: 2                            --                        --
│    │    └─ReLU: 3-11                   [1, 10]                   --
│    └─CNN: 2                            --                        --
│    │    └─MaxPool1d: 3-9               [1, 10]                   --
│    └─CNN: 2-3                          [1, 10]                   --
│    │    └─Conv1d: 3-10                 [1, 10]                   2
│    │    └─ReLU: 3-11                   [1, 10]                   --
│    │    └─MaxPool1d: 3-12              [1, 10]                   --
==========================================================================================
Total params: 6
Trainable params: 6
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

Expected behavior Swapping

        self.relu = nn.ReLU()
        self.conv = nn.Conv1d(1, 1, 1) # CNN submodules - note that ReLU is instantiated BEFORE Conv1d
        self.pool = nn.MaxPool1d(1)

with

        self.conv = nn.Conv1d(1, 1, 1) # CNN submodules - note that ReLU is instantiated AFTER Conv1d
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(1)

seems to resolve the issue and produces the expected output:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MainModule                               [1, 10]                   --
├─ModuleList: 1-1                        --                        --
│    └─CNN: 2-1                          [1, 10]                   --
│    │    └─Conv1d: 3-1                  [1, 10]                   2
│    │    └─ReLU: 3-2                    [1, 10]                   --
│    │    └─MaxPool1d: 3-3               [1, 10]                   --
│    └─CNN: 2-2                          [1, 10]                   --
│    │    └─Conv1d: 3-4                  [1, 10]                   2
│    │    └─ReLU: 3-5                    [1, 10]                   --
│    │    └─MaxPool1d: 3-6               [1, 10]                   --
│    └─CNN: 2-3                          [1, 10]                   --
│    │    └─Conv1d: 3-7                  [1, 10]                   2
│    │    └─ReLU: 3-8                    [1, 10]                   --
│    │    └─MaxPool1d: 3-9               [1, 10]                   --
==========================================================================================
Total params: 6
Trainable params: 6
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

Screenshots N/A


Additional context For my case there is a workaround, but I wanted to make the issue known in case no workaround exists for some other usage, and to inform other users of the problem. Thank you for taking the time to read this issue!
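For anyone hitting this before it is fixed, a minimal sketch of the workaround (assuming the goal is just a clean summary): either declare the submodules in the order forward calls them, as in the swap above, or wrap them in an nn.Sequential, which registers its children in execution order by construction. The class name CNNSequential is illustrative, not from the original report.

from torch import nn

class CNNSequential(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Sequential registers its children in the order given,
        # so declaration order cannot drift from call order.
        self.block = nn.Sequential(
            nn.Conv1d(1, 1, 1),
            nn.ReLU(),
            nn.MaxPool1d(1),
        )

    def forward(self, x):
        return self.block(x)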

TylerYep commented 2 years ago

Hi, this is a known issue and should be fixed by #169. I'll add your example as a test case to ensure it works going forward.
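(For illustration, a hypothetical pytest-style regression test along these lines; the test name and the exact assertions are assumptions, not the actual test added to the repository.)

import torchinfo
from torch import nn

def test_module_list_out_of_order_declaration():
    # Submodules declared in a different order than forward() calls them.
    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()
            self.conv = nn.Conv1d(1, 1, 1)
            self.pool = nn.MaxPool1d(1)

        def forward(self, x):
            return self.pool(self.relu(self.conv(x)))

    class MainModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.ml = nn.ModuleList([CNN() for _ in range(3)])

        def forward(self, x):
            for layer in self.ml:
                x = layer(x)
            return x

    results = torchinfo.summary(MainModule(), input_size=[1, 10], verbose=0)
    # Three Conv1d(1, 1, 1) layers, each with one weight and one bias.
    assert results.total_params == 6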

TylerYep commented 2 years ago

Should be fixed in v1.7.1
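(For anyone verifying the fix, a quick standard-library check of the installed version; the 1.7.1 floor comes from the comment above.)

from importlib.metadata import version

print(version("torchinfo"))  # expect 1.7.1 or newer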

Ben-Drucker commented 2 years ago

Thanks!