TylerYep / torchinfo

View model summaries in PyTorch!
MIT License
2.5k stars 118 forks source link

Misidentification of recursiveness (?) #262

Open kswannet opened 1 year ago

kswannet commented 1 year ago

When a model contains more than one `nn.Sequential` module and they share the same activation-function instance defined in `__init__`, torchinfo splits the single `nn.Sequential` into separate ones at each activation-function call.

For example:

class TestNetwork(nn.Module):
    """Repro model: two MLP stacks built from the same layer widths.

    Only ``firstNetwork`` is used in ``forward``; ``secondNetwork`` is built but
    never called. Both stacks insert the very same ``self.actFun`` module object
    after every Linear layer. Sharing that one object across both Sequentials
    appears to be what produces the "(recursive)" rows in the torchinfo summary.
    """
    def __init__(self, *, input_dim, h_dim, output_dim, actFun=nn.LeakyReLU(), dropRate=0.3):
        """Build both stacks.

        Args:
            input_dim: in_features of the first Linear layer.
            h_dim: hidden-layer widths. It is concatenated with lists below, so it
                must be a list.
            output_dim: out_features of the last Linear layer.
            actFun: activation module placed after every Linear layer. NOTE: the
                default ``nn.LeakyReLU()`` is evaluated once, when the ``def`` runs,
                so every TestNetwork built without an explicit ``actFun`` shares the
                same object. Within one model that object is registered as
                ``self.actFun`` and also under every ``actFirst_{i}`` and
                ``actSecnd_{i}`` name.
            dropRate: dropout probability for every ``nn.Dropout`` layer.
        """

        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.output_dim = output_dim
        self.actFun = actFun  # assigning a Module attribute registers it as a child of this model
        self.dropRate = dropRate

        layer_dim = [input_dim] + h_dim + [output_dim]   # widths at each layer boundary: input, hidden..., output

        # Stage i: Linear(layer_dim[i] -> layer_dim[i+1]), then the shared activation, then Dropout.
        self.firstNetwork = nn.Sequential()
        for i in range(len(h_dim)+1):
            self.firstNetwork.add_module(f"fcFirst_{i}", nn.Linear(layer_dim[i], layer_dim[i+1]))
            self.firstNetwork.add_module(f"actFirst_{i}", self.actFun)  # same object at every stage
            self.firstNetwork.add_module(f"dropFirst_{i}", nn.Dropout(dropRate))

        # Same widths, never called in forward(); it still reuses the same activation object.
        self.secondNetwork = nn.Sequential()
        for i in range(len(h_dim)+1):
            self.secondNetwork.add_module(f"fcSecnd_{i}", nn.Linear(layer_dim[i], layer_dim[i+1]))
            self.secondNetwork.add_module(f"actSecnd_{i}", self.actFun)  # shared with firstNetwork
            self.secondNetwork.add_module(f"dropSecnd_{i}", nn.Dropout(dropRate))
        del self.secondNetwork[-1]            # remove the trailing Dropout (secondNetwork only)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through ``firstNetwork`` only; ``secondNetwork`` is unused."""
        return self.firstNetwork(x)

if __name__ == "__main__":
    # Build the repro model with hidden widths 64 and 32 and print its
    # torchinfo summary for an input of shape [16, 199].
    net = TestNetwork(
        input_dim=199,
        h_dim=[64, 32],
        output_dim=100,
        )
    print(str(summary(net, input_size=[16, 199])))

results in the following summary:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
TestNetwork                              [16, 100]                 12,800
├─Sequential: 1-1                        [16, 100]                 5,380
│    └─Linear: 2-1                       [16, 64]                  12,800
├─Sequential: 1-6                        --                        (recursive)
│    └─LeakyReLU: 2-2                    [16, 64]                  --
├─Sequential: 1-7                        --                        (recursive)
│    └─Dropout: 2-3                      [16, 64]                  --
│    └─Linear: 2-4                       [16, 32]                  2,080
├─Sequential: 1-6                        --                        (recursive)
│    └─LeakyReLU: 2-5                    [16, 32]                  --
├─Sequential: 1-7                        --                        (recursive)
│    └─Dropout: 2-6                      [16, 32]                  --
│    └─Linear: 2-7                       [16, 100]                 3,300
├─Sequential: 1-6                        --                        (recursive)
│    └─LeakyReLU: 2-8                    [16, 100]                 --
├─Sequential: 1-7                        --                        (recursive)
│    └─Dropout: 2-9                      [16, 100]                 --
==========================================================================================

Even though `secondNetwork` is never used in `forward`, replacing its `self.actFun` reference with a fresh `nn.LeakyReLU()` instance fixes the problem:

class TestNetwork(nn.Module):
    """Variant of the repro model whose torchinfo summary prints correctly.

    Identical to the first version except that ``secondNetwork`` gets a fresh
    ``nn.LeakyReLU()`` at each stage. The ``actFun`` object is therefore shared
    only between ``self.actFun`` and ``firstNetwork``, not across both Sequentials.
    Only ``firstNetwork`` is used in ``forward``.
    """
    def __init__(self, *, input_dim, h_dim, output_dim, actFun=nn.LeakyReLU(), dropRate=0.3):
        """Build both stacks.

        Args:
            input_dim: in_features of the first Linear layer.
            h_dim: hidden-layer widths. It is concatenated with lists below, so it
                must be a list.
            output_dim: out_features of the last Linear layer.
            actFun: activation module placed after every Linear layer of
                ``firstNetwork``. NOTE: the default ``nn.LeakyReLU()`` is evaluated
                once, when the ``def`` runs, and is shared by every TestNetwork built
                without an explicit ``actFun``.
            dropRate: dropout probability for every ``nn.Dropout`` layer.
        """

        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.output_dim = output_dim
        self.actFun = actFun  # assigning a Module attribute registers it as a child of this model
        self.dropRate = dropRate

        layer_dim = [input_dim] + h_dim + [output_dim]   # widths at each layer boundary: input, hidden..., output

        # Stage i: Linear(layer_dim[i] -> layer_dim[i+1]), then the shared activation, then Dropout.
        self.firstNetwork = nn.Sequential()
        for i in range(len(h_dim)+1):
            self.firstNetwork.add_module(f"fcFirst_{i}", nn.Linear(layer_dim[i], layer_dim[i+1]))
            self.firstNetwork.add_module(f"actFirst_{i}", self.actFun)  # same object at every stage
            self.firstNetwork.add_module(f"dropFirst_{i}", nn.Dropout(dropRate))

        # Same widths, never called in forward(); each stage gets its own activation object.
        self.secondNetwork = nn.Sequential()
        for i in range(len(h_dim)+1):
            self.secondNetwork.add_module(f"fcSecnd_{i}", nn.Linear(layer_dim[i], layer_dim[i+1]))
            self.secondNetwork.add_module(f"actSecnd_{i}", nn.LeakyReLU())  # fresh instance, not shared with firstNetwork
            self.secondNetwork.add_module(f"dropSecnd_{i}", nn.Dropout(dropRate))
        del self.secondNetwork[-1]            # remove the trailing Dropout (secondNetwork only)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through ``firstNetwork`` only; ``secondNetwork`` is unused."""
        return self.firstNetwork(x)

if __name__ == "__main__":
    # Same driver as the first repro: hidden widths 64 and 32, summary for an
    # input of shape [16, 199].
    net = TestNetwork(
        input_dim=199,
        h_dim=[64, 32],
        output_dim=100,
        )
    print(str(summary(net, input_size=[16, 199])))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
TestNetwork                              [16, 100]                 18,180
├─Sequential: 1-1                        [16, 100]                 --
│    └─Linear: 2-1                       [16, 64]                  12,800
│    └─LeakyReLU: 2-2                    [16, 64]                  --
│    └─Dropout: 2-3                      [16, 64]                  --
│    └─Linear: 2-4                       [16, 32]                  2,080
│    └─LeakyReLU: 2-5                    [16, 32]                  --
│    └─Dropout: 2-6                      [16, 32]                  --
│    └─Linear: 2-7                       [16, 100]                 3,300
│    └─LeakyReLU: 2-8                    [16, 100]                 --
│    └─Dropout: 2-9                      [16, 100]                 --
==========================================================================================

Is this a torchinfo problem, or am I doing something wrong here?

mykappa commented 1 year ago

I have exactly the same problem. With more complex models, the output becomes very confusing and hard to follow.

Update: I found that adding the option `"hide_recursive_layers"` to `row_settings` improved the output a lot in my case.