TylerYep / torchinfo

View model summaries in PyTorch!
MIT License
2.56k stars 119 forks source link

Support summaries for nested `ModuleDict`s #184

Closed justinxzhao closed 1 year ago

justinxzhao commented 1 year ago

The Ludwig project (ludwig.ai, https://github.com/ludwig-ai/ludwig) uses nested ModuleDicts to organize and manage internal neural architectures.

Here's some sample output of torchinfo. Many torch modules are collapsed:

==================================================================================================================================
Layer (type:depth-idx)                                                           Output Shape              Param #
==================================================================================================================================
ECD                                                                              [2]                       --
├─ConcatCombiner: 1-1                                                            --                        (recursive)
│    └─LudwigFeatureDict: 2-1                                                    --                        --
│    │    └─ModuleDict: 3-1                                                      --                        66,362,880
├─ConcatCombiner: 1-2                                                            [2]                       66,362,880
│    └─FCStack: 2-2                                                              [2, 128]                  --
│    │    └─ModuleList: 3-2                                                      --                        131,456
├─LudwigFeatureDict: 1-3                                                         --                        --
│    └─ModuleDict: 2-3                                                           --                        --
│    │    └─BinaryOutputFeature: 3-3                                             [2]                       7,329
==================================================================================================================================
Total params: 66,501,665
Trainable params: 66,501,665
Non-trainable params: 0
Total mult-adds (M): 132.61
==================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 12.63
Params size (MB): 266.01
Estimated Total Size (MB): 278.64
==================================================================================================================================

Would this be difficult to build support for?

TylerYep commented 1 year ago

Have you tried setting a larger depth?, e.g. depth=20

justinxzhao commented 1 year ago

Ah! Great tip -- depth=20 seems to work wonderfully.

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
ECD                                                     [2]                       --
├─ConcatCombiner: 1-1                                   --                        (recursive)
│    └─LudwigFeatureDict: 2-1                           --                        --
│    │    └─ModuleDict: 3-1                             --                        --
│    │    │    └─NumberInputFeature: 4-1                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-1           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-2              [2, 9]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-2      [2, 9]                    --
│    │    │    │    │    └─Embed: 6-1                   [2, 9]                    --
│    │    │    │    │    │    └─Embedding: 7-1          [2, 1, 9]                 81
│    │    │    └─NumberInputFeature: 4-3                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-3           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-4              [2, 16]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-4      [2, 16]                   --
│    │    │    │    │    └─Embed: 6-2                   [2, 16]                   --
│    │    │    │    │    │    └─Embedding: 7-2          [2, 1, 16]                256
│    │    │    └─NumberInputFeature: 4-5                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-5           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-6              [2, 7]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-6      [2, 7]                    --
│    │    │    │    │    └─Embed: 6-3                   [2, 7]                    --
│    │    │    │    │    │    └─Embedding: 7-3          [2, 1, 7]                 49
│    │    │    └─CategoryInputFeature: 4-7              [2, 15]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-7      [2, 15]                   --
│    │    │    │    │    └─Embed: 6-4                   [2, 15]                   --
│    │    │    │    │    │    └─Embedding: 7-4          [2, 1, 15]                225
│    │    │    └─CategoryInputFeature: 4-8              [2, 6]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-8      [2, 6]                    --
│    │    │    │    │    └─Embed: 6-5                   [2, 6]                    --
│    │    │    │    │    │    └─Embedding: 7-5          [2, 1, 6]                 36
│    │    │    └─CategoryInputFeature: 4-9              [2, 5]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-9      [2, 5]                    --
│    │    │    │    │    └─Embed: 6-6                   [2, 5]                    --
│    │    │    │    │    │    └─Embedding: 7-6          [2, 1, 5]                 25
│    │    │    └─CategoryInputFeature: 4-10             [2, 2]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-10     [2, 2]                    --
│    │    │    │    │    └─Embed: 6-7                   [2, 2]                    --
│    │    │    │    │    │    └─Embedding: 7-7          [2, 1, 2]                 4
│    │    │    └─NumberInputFeature: 4-11               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-11          [2, 1]                    --
│    │    │    └─NumberInputFeature: 4-12               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-12          [2, 1]                    --
│    │    │    └─NumberInputFeature: 4-13               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-13          [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-14             [2, 42]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-14     [2, 42]                   --
│    │    │    │    │    └─Embed: 6-8                   [2, 42]                   --
│    │    │    │    │    │    └─Embedding: 7-8          [2, 1, 42]                1,764
├─ConcatCombiner: 1-2                                   [2, 128]                  2,440
│    └─FCStack: 2-2                                     [2, 128]                  --
│    │    └─ModuleList: 3-2                             --                        --
│    │    │    └─FCLayer: 4-15                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-15                  --                        --
│    │    │    │    │    └─Linear: 6-9                  [2, 128]                  13,952
│    │    │    │    │    └─ReLU: 6-10                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-11                [2, 128]                  --
│    │    │    └─FCLayer: 4-16                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-16                  --                        --
│    │    │    │    │    └─Linear: 6-12                 [2, 128]                  16,512
│    │    │    │    │    └─ReLU: 6-13                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-14                [2, 128]                  --
│    │    │    └─FCLayer: 4-17                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-17                  --                        --
│    │    │    │    │    └─Linear: 6-15                 [2, 128]                  16,512
│    │    │    │    │    └─ReLU: 6-16                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-17                [2, 128]                  --
├─LudwigFeatureDict: 1-3                                --                        --
│    └─ModuleDict: 2-3                                  --                        --
│    │    └─BinaryOutputFeature: 3-3                    [2]                       --
│    │    │    └─FCStack: 4-18                          [2, 32]                   --
│    │    │    │    └─ModuleList: 5-18                  --                        --
│    │    │    │    │    └─FCLayer: 6-18                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-9         --                        --
│    │    │    │    │    │    │    └─Linear: 8-1        [2, 32]                   4,128
│    │    │    │    │    │    │    └─ReLU: 8-2          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-19                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-10        --                        --
│    │    │    │    │    │    │    └─Linear: 8-3        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-4          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-20                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-11        --                        --
│    │    │    │    │    │    │    └─Linear: 8-5        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-6          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-21                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-12        --                        --
│    │    │    │    │    │    │    └─Linear: 8-7        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-8          [2, 32]                   --
│    │    │    └─Regressor: 4-19                        [2]                       --
│    │    │    │    └─Dense: 5-19                       [2]                       --
│    │    │    │    │    └─Linear: 6-22                 [2, 1]                    33
=========================================================================================================
Total params: 56,745
Trainable params: 56,745
Non-trainable params: 0
Total mult-adds (M): 0.11
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.23
Estimated Total Size (MB): 0.24
=========================================================================================================