ludwig-ai / ludwig

Low-code framework for building custom LLMs, neural networks, and other AI models
http://ludwig.ai
Apache License 2.0

Is it possible to visualize the model network? #2652

Closed: qiagu closed this issue 1 year ago

qiagu commented 2 years ago

It would be great to be able to use tools like TorchViz to generate a model structure graph in Ludwig.

tgaddair commented 2 years ago

Hey @qiagu, we have written some tools for generating model visualizations that look like this, though we haven't yet added them officially to Ludwig:

https://raw.githubusercontent.com/ludwig-ai/ludwig-docs/master/docs/images/ludwig_legos_unanimated.gif

Is this the kind of visualization you're looking for, or something more granular?

qiagu commented 2 years ago

@tgaddair I was interested in a more detailed model structure, for reasons such as: 1) comparing with models created directly with Torch or TensorFlow, and 2) visually checking whether a customized Ludwig model is built as planned. The lego visualization you posted looks awesome as well.

justinxzhao commented 2 years ago

Hi @qiagu,

I checked out TorchViz. A LudwigModel can certainly provide the underlying torch Module as well as sample inputs, which is how we support TorchScript compilation.

However, it looks like the tool doesn't support dictionary-based input tracing yet. :(

from ludwig.api import LudwigModel
from torchviz import make_dot

model = LudwigModel(config=...)

# The forward pass returns a dict, which make_dot cannot hash (see the error below).
make_dot(
    model.model(model.model.get_model_inputs()),
    params=dict(model.model.named_parameters()),
)

Results in this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-26-24af005d1343> in <module>
      4 print(model.model.get_model_inputs())
      5 
----> 6 make_dot(model.model(model.model.get_model_inputs()), params=dict(model.model.named_parameters()))

1 frames
/usr/local/lib/python3.7/dist-packages/torchviz/dot.py in add_base_tensor(var, color)
    144 
    145     def add_base_tensor(var, color='darkolivegreen1'):
--> 146         if var in seen:
    147             return
    148         seen.add(var)

TypeError: unhashable type: 'dict'

It could be worth filing an issue with the TorchViz project and linking it to this one. WDYT?
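
One possible workaround, offered as a sketch rather than a tested fix: the traceback suggests that make_dot was handed a dict (the model's output), whereas it works with a tensor or a tuple of tensors. Flattening the output into its leaf values first might sidestep the error. The helper `flatten_leaves` below is hypothetical (it is not part of Ludwig or TorchViz) and is written in plain Python so the logic can be checked without either library. The usage shown in the trailing comments is an untested assumption about how it would be wired up.

```python
from typing import Any, Callable, List, Tuple


def flatten_leaves(
    obj: Any, is_leaf: Callable[[Any], bool], prefix: str = ""
) -> List[Tuple[str, Any]]:
    """Walk nested dicts/lists/tuples and return (path, leaf) pairs.

    Values that are neither leaves nor containers are skipped.
    """
    if is_leaf(obj):
        return [(prefix, obj)]
    if isinstance(obj, dict):
        items = list(obj.items())
    elif isinstance(obj, (list, tuple)):
        items = list(enumerate(obj))
    else:
        return []
    found: List[Tuple[str, Any]] = []
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        found.extend(flatten_leaves(value, is_leaf, path))
    return found


# Hypothetical usage with TorchViz (assumption, not run here):
#   import torch
#   outputs = model.model(model.model.get_model_inputs())
#   leaves = flatten_leaves(
#       outputs,
#       lambda x: isinstance(x, torch.Tensor) and x.grad_fn is not None,
#   )
#   make_dot(
#       tuple(tensor for _, tensor in leaves),
#       params=dict(model.model.named_parameters()),
#   )
```

Keeping only tensors that carry a grad_fn is a deliberate choice in this sketch, since the graph is built by walking the autograd history and tensors without one would contribute nothing to it.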

qiagu commented 2 years ago

@justinxzhao Thanks for trying that, I appreciate it a lot. Could you please try pytorch-summary as well? It would be great to see a summary of a Ludwig model. See https://stackoverflow.com/questions/42480111/how-do-i-print-the-model-summary-in-pytorch

justinxzhao commented 2 years ago

Hi @qiagu,

torchinfo.summary works.

import torchinfo

print(torchinfo.summary(model.model, input_data=[model.model.get_model_inputs()]))

Here's an example output:

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
ECD                                                     [2]                       --
├─ConcatCombiner: 1-1                                   --                        (recursive)
│    └─LudwigFeatureDict: 2-1                           --                        --
│    │    └─ModuleDict: 3-1                             --                        2,440
├─ConcatCombiner: 1-2                                   [2, 128]                  2,440
│    └─FCStack: 2-2                                     [2, 128]                  --
│    │    └─ModuleList: 3-2                             --                        46,976
├─LudwigFeatureDict: 1-3                                --                        --
│    └─ModuleDict: 2-3                                  --                        --
│    │    └─BinaryOutputFeature: 3-3                    [2]                       7,329
=========================================================================================================
Total params: 56,745
Trainable params: 56,745
Non-trainable params: 0
Total mult-adds (M): 0.11
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.23
Estimated Total Size (MB): 0.24
=========================================================================================================

qiagu commented 2 years ago

@justinxzhao Thanks for your nice work. It looks like the summary doesn't add much information beyond the lego animation that @tgaddair showed above.

justinxzhao commented 1 year ago

I filed an issue with torchinfo to assess how difficult it would be to add support for nested ModuleDicts.

https://github.com/TylerYep/torchinfo/issues/184

justinxzhao commented 1 year ago

@qiagu It looks like you can get a more detailed torchinfo summary by passing depth=20:

print(
    torchinfo.summary(
        model.model, input_data=[model.model.get_model_inputs()],
        depth=20
    )
)

Sample:

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
ECD                                                     [2]                       --
├─ConcatCombiner: 1-1                                   --                        (recursive)
│    └─LudwigFeatureDict: 2-1                           --                        --
│    │    └─ModuleDict: 3-1                             --                        --
│    │    │    └─NumberInputFeature: 4-1                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-1           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-2              [2, 9]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-2      [2, 9]                    --
│    │    │    │    │    └─Embed: 6-1                   [2, 9]                    --
│    │    │    │    │    │    └─Embedding: 7-1          [2, 1, 9]                 81
│    │    │    └─NumberInputFeature: 4-3                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-3           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-4              [2, 16]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-4      [2, 16]                   --
│    │    │    │    │    └─Embed: 6-2                   [2, 16]                   --
│    │    │    │    │    │    └─Embedding: 7-2          [2, 1, 16]                256
│    │    │    └─NumberInputFeature: 4-5                [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-5           [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-6              [2, 7]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-6      [2, 7]                    --
│    │    │    │    │    └─Embed: 6-3                   [2, 7]                    --
│    │    │    │    │    │    └─Embedding: 7-3          [2, 1, 7]                 49
│    │    │    └─CategoryInputFeature: 4-7              [2, 15]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-7      [2, 15]                   --
│    │    │    │    │    └─Embed: 6-4                   [2, 15]                   --
│    │    │    │    │    │    └─Embedding: 7-4          [2, 1, 15]                225
│    │    │    └─CategoryInputFeature: 4-8              [2, 6]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-8      [2, 6]                    --
│    │    │    │    │    └─Embed: 6-5                   [2, 6]                    --
│    │    │    │    │    │    └─Embedding: 7-5          [2, 1, 6]                 36
│    │    │    └─CategoryInputFeature: 4-9              [2, 5]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-9      [2, 5]                    --
│    │    │    │    │    └─Embed: 6-6                   [2, 5]                    --
│    │    │    │    │    │    └─Embedding: 7-6          [2, 1, 5]                 25
│    │    │    └─CategoryInputFeature: 4-10             [2, 2]                    --
│    │    │    │    └─CategoricalEmbedEncoder: 5-10     [2, 2]                    --
│    │    │    │    │    └─Embed: 6-7                   [2, 2]                    --
│    │    │    │    │    │    └─Embedding: 7-7          [2, 1, 2]                 4
│    │    │    └─NumberInputFeature: 4-11               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-11          [2, 1]                    --
│    │    │    └─NumberInputFeature: 4-12               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-12          [2, 1]                    --
│    │    │    └─NumberInputFeature: 4-13               [2, 1]                    --
│    │    │    │    └─PassthroughEncoder: 5-13          [2, 1]                    --
│    │    │    └─CategoryInputFeature: 4-14             [2, 42]                   --
│    │    │    │    └─CategoricalEmbedEncoder: 5-14     [2, 42]                   --
│    │    │    │    │    └─Embed: 6-8                   [2, 42]                   --
│    │    │    │    │    │    └─Embedding: 7-8          [2, 1, 42]                1,764
├─ConcatCombiner: 1-2                                   [2, 128]                  2,440
│    └─FCStack: 2-2                                     [2, 128]                  --
│    │    └─ModuleList: 3-2                             --                        --
│    │    │    └─FCLayer: 4-15                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-15                  --                        --
│    │    │    │    │    └─Linear: 6-9                  [2, 128]                  13,952
│    │    │    │    │    └─ReLU: 6-10                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-11                [2, 128]                  --
│    │    │    └─FCLayer: 4-16                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-16                  --                        --
│    │    │    │    │    └─Linear: 6-12                 [2, 128]                  16,512
│    │    │    │    │    └─ReLU: 6-13                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-14                [2, 128]                  --
│    │    │    └─FCLayer: 4-17                          [2, 128]                  --
│    │    │    │    └─ModuleList: 5-17                  --                        --
│    │    │    │    │    └─Linear: 6-15                 [2, 128]                  16,512
│    │    │    │    │    └─ReLU: 6-16                   [2, 128]                  --
│    │    │    │    │    └─Dropout: 6-17                [2, 128]                  --
├─LudwigFeatureDict: 1-3                                --                        --
│    └─ModuleDict: 2-3                                  --                        --
│    │    └─BinaryOutputFeature: 3-3                    [2]                       --
│    │    │    └─FCStack: 4-18                          [2, 32]                   --
│    │    │    │    └─ModuleList: 5-18                  --                        --
│    │    │    │    │    └─FCLayer: 6-18                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-9         --                        --
│    │    │    │    │    │    │    └─Linear: 8-1        [2, 32]                   4,128
│    │    │    │    │    │    │    └─ReLU: 8-2          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-19                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-10        --                        --
│    │    │    │    │    │    │    └─Linear: 8-3        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-4          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-20                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-11        --                        --
│    │    │    │    │    │    │    └─Linear: 8-5        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-6          [2, 32]                   --
│    │    │    │    │    └─FCLayer: 6-21                [2, 32]                   --
│    │    │    │    │    │    └─ModuleList: 7-12        --                        --
│    │    │    │    │    │    │    └─Linear: 8-7        [2, 32]                   1,056
│    │    │    │    │    │    │    └─ReLU: 8-8          [2, 32]                   --
│    │    │    └─Regressor: 4-19                        [2]                       --
│    │    │    │    └─Dense: 5-19                       [2]                       --
│    │    │    │    │    └─Linear: 6-22                 [2, 1]                    33
=========================================================================================================
Total params: 56,745
Trainable params: 56,745
Non-trainable params: 0
Total mult-adds (M): 0.11
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.23
Estimated Total Size (MB): 0.24
=========================================================================================================

Hopefully this is more useful to you :)
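
As a sanity check on the table above, the parameter counts can be reproduced by hand from the listed output shapes. The sketch below assumes that every Linear layer has a bias (in * out + out parameters) and that each category Embedding is a square N x N table. Both assumptions are inferred from the numbers in the summary, not from Ludwig's source code.

```python
# Widths read from the CategoryInputFeature rows (output shape [2, N]).
category_sizes = [9, 16, 7, 15, 6, 5, 2, 42]
# Six NumberInputFeature rows, each passing through a width of 1.
num_number_features = 6


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus bias of a fully connected layer (bias assumed)."""
    return n_in * n_out + n_out


# Assumption: each Embedding holds N vectors of size N.
embedding_params = sum(n * n for n in category_sizes)

# The concat combiner sees all encoder outputs side by side.
combiner_in = sum(category_sizes) + num_number_features
combiner_params = linear_params(combiner_in, 128) + 2 * linear_params(128, 128)

# Output feature: four 32-wide FC layers followed by a single-unit projection.
output_params = (
    linear_params(128, 32) + 3 * linear_params(32, 32) + linear_params(32, 1)
)

total_params = embedding_params + combiner_params + output_params
print(embedding_params, combiner_in, combiner_params, output_params, total_params)
```

Under these assumptions the three subtotals line up with the 2,440, 46,976, and 7,329 figures in the shallower summary, and their sum matches the reported total of 56,745.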

qiagu commented 1 year ago

@justinxzhao Awesome! Thank you for solving the problem perfectly.