DrChainsaw / ONNXNaiveNASflux.jl

Import/export ONNX models
MIT License

Showing what Flux operations are used within a CompGraph #47

Closed vtjeng closed 2 months ago

vtjeng commented 3 years ago

I was hoping to use ONNXmutable.jl to extract the Flux operations specified in an .onnx file. Specifically, I'd like to be able to provide an .onnx file and load in a function corresponding to a composition of Flux primitives - something like

f = Chain(
  Dense(10, 5, σ),
  Dense(5, 2),
  softmax
)

What I've tried

Use Case

[1] I'm using Flux because it seems to be the best supported framework related to neural networks in Julia, but would be open to suggestions to consider other options.

Additional Notes

The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).

DrChainsaw commented 3 years ago

Hi and thanks for showing interest!

The CompGraph has many methods for introspection and most of them are documented in NaiveNASflux.

Here is an example of the kind of ad-hoc summary table I usually use:

julia> cg = CompGraph("resnet18-v1-7.onnx");

# Return vertices of cg as a topologically sorted array
julia> vs = vertices(cg);

# Workaround for an oversight in NaiveNASflux which I will correct one day; without it, layer fails when it hits a non-Flux layer such as the element-wise additions of the resnet
julia> NaiveNASflux.layer(f) = f;

julia> [name.(vs) nin.(vs) nout.(vs) layer.(vs) map(ivs -> name.(ivs), inputs.(vs)) map(ovs -> name.(ovs), outputs.(vs))]
61×6 Matrix{Any}:
 "data"                             Any[]          3  LayerTypeWrapper(FluxConv{2}())                       Any[]                                                                   ["resnetv15_conv0_fwd"]
 "resnetv15_conv0_fwd"              [3]           64  Conv((7, 7), 3=>64)                                   ["data"]                                                                ["resnetv15_batchnorm0_fwd"]
 "resnetv15_batchnorm0_fwd"         [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_conv0_fwd"]                                                 ["resnetv15_pool0_fwd"]
 "resnetv15_pool0_fwd"              [64]          64  MaxPool((3, 3), pad = (1, 1, 1, 1), stride = (2, 2))  ["resnetv15_batchnorm0_fwd"]                                            ["resnetv15_stage1_conv0_fwd", "resnetv15_stage1__plus0"]
 "resnetv15_stage1_conv0_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_pool0_fwd"]                                                 ["resnetv15_stage1_batchnorm0_fwd"]
 "resnetv15_stage1_batchnorm0_fwd"  [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_stage1_conv0_fwd"]                                          ["resnetv15_stage1_conv1_fwd"]
 "resnetv15_stage1_conv1_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_batchnorm0_fwd"]                                     ["resnetv15_stage1_batchnorm1_fwd"]
 "resnetv15_stage1_batchnorm1_fwd"  [64]          64  BatchNorm(64)                                         ["resnetv15_stage1_conv1_fwd"]                                          ["resnetv15_stage1__plus0"]
 "resnetv15_stage1__plus0"          [64, 64]      64  #225                                                  ["resnetv15_pool0_fwd", "resnetv15_stage1_batchnorm1_fwd"]              ["resnetv15_stage1_activation0"]
 "resnetv15_stage1_activation0"     [64]          64  #195                                                  ["resnetv15_stage1__plus0"]                                             ["resnetv15_stage1_conv2_fwd", "resnetv15_stage1__plus1"]
 "resnetv15_stage1_conv2_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_activation0"]                                        ["resnetv15_stage1_batchnorm2_fwd"]
 "resnetv15_stage1_batchnorm2_fwd"  [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_stage1_conv2_fwd"]                                          ["resnetv15_stage1_conv3_fwd"]
 "resnetv15_stage1_conv3_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_batchnorm2_fwd"]                                     ["resnetv15_stage1_batchnorm3_fwd"]
 "resnetv15_stage1_batchnorm3_fwd"  [64]          64  BatchNorm(64)                                         ["resnetv15_stage1_conv3_fwd"]                                          ["resnetv15_stage1__plus1"]
 "resnetv15_stage1__plus1"          [64, 64]      64  #225                                                  ["resnetv15_stage1_activation0", "resnetv15_stage1_batchnorm3_fwd"]     ["resnetv15_stage1_activation1"]
 "resnetv15_stage1_activation1"     [64]          64  #195                                                  ["resnetv15_stage1__plus1"]                                             ["resnetv15_stage2_conv2_fwd", "resnetv15_stage2_conv0_fwd"]
 "resnetv15_stage2_conv2_fwd"       [64]         128  Conv((1, 1), 64=>128)                                 ["resnetv15_stage1_activation1"]                                        ["resnetv15_stage2_batchnorm2_fwd"]
 "resnetv15_stage2_batchnorm2_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv2_fwd"]                                          ["resnetv15_stage2__plus0"]
 "resnetv15_stage2_conv0_fwd"       [64]         128  Conv((3, 3), 64=>128)                                 ["resnetv15_stage1_activation1"]                                        ["resnetv15_stage2_batchnorm0_fwd"]
 "resnetv15_stage2_batchnorm0_fwd"  [128]        128  BatchNorm(128, λ = relu)                              ["resnetv15_stage2_conv0_fwd"]                                          ["resnetv15_stage2_conv1_fwd"]
 "resnetv15_stage2_conv1_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_batchnorm0_fwd"]                                     ["resnetv15_stage2_batchnorm1_fwd"]
 "resnetv15_stage2_batchnorm1_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv1_fwd"]                                          ["resnetv15_stage2__plus0"]
 "resnetv15_stage2__plus0"          [128, 128]   128  #225                                                  ["resnetv15_stage2_batchnorm2_fwd", "resnetv15_stage2_batchnorm1_fwd"]  ["resnetv15_stage2_activation0"]
 "resnetv15_stage2_activation0"     [128]        128  #195                                                  ["resnetv15_stage2__plus0"]                                             ["resnetv15_stage2_conv3_fwd", "resnetv15_stage2__plus1"]
 "resnetv15_stage2_conv3_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_activation0"]                                        ["resnetv15_stage2_batchnorm3_fwd"]
 "resnetv15_stage2_batchnorm3_fwd"  [128]        128  BatchNorm(128, λ = relu)                              ["resnetv15_stage2_conv3_fwd"]                                          ["resnetv15_stage2_conv4_fwd"]
 "resnetv15_stage2_conv4_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_batchnorm3_fwd"]                                     ["resnetv15_stage2_batchnorm4_fwd"]
 "resnetv15_stage2_batchnorm4_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv4_fwd"]                                          ["resnetv15_stage2__plus1"]
 "resnetv15_stage2__plus1"          [128, 128]   128  #225                                                  ["resnetv15_stage2_activation0", "resnetv15_stage2_batchnorm4_fwd"]     ["resnetv15_stage2_activation1"]
 "resnetv15_stage2_activation1"     [128]        128  #195                                                  ["resnetv15_stage2__plus1"]                                             ["resnetv15_stage3_conv2_fwd", "resnetv15_stage3_conv0_fwd"]
 "resnetv15_stage3_conv2_fwd"       [128]        256  Conv((1, 1), 128=>256)                                ["resnetv15_stage2_activation1"]                                        ["resnetv15_stage3_batchnorm2_fwd"]
 "resnetv15_stage3_batchnorm2_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv2_fwd"]                                          ["resnetv15_stage3__plus0"]
 "resnetv15_stage3_conv0_fwd"       [128]        256  Conv((3, 3), 128=>256)                                ["resnetv15_stage2_activation1"]                                        ["resnetv15_stage3_batchnorm0_fwd"]
 "resnetv15_stage3_batchnorm0_fwd"  [256]        256  BatchNorm(256, λ = relu)                              ["resnetv15_stage3_conv0_fwd"]                                          ["resnetv15_stage3_conv1_fwd"]
 "resnetv15_stage3_conv1_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_batchnorm0_fwd"]                                     ["resnetv15_stage3_batchnorm1_fwd"]
 "resnetv15_stage3_batchnorm1_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv1_fwd"]                                          ["resnetv15_stage3__plus0"]
 "resnetv15_stage3__plus0"          [256, 256]   256  #225                                                  ["resnetv15_stage3_batchnorm2_fwd", "resnetv15_stage3_batchnorm1_fwd"]  ["resnetv15_stage3_activation0"]
 "resnetv15_stage3_activation0"     [256]        256  #195                                                  ["resnetv15_stage3__plus0"]                                             ["resnetv15_stage3_conv3_fwd", "resnetv15_stage3__plus1"]
 "resnetv15_stage3_conv3_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_activation0"]                                        ["resnetv15_stage3_batchnorm3_fwd"]
 "resnetv15_stage3_batchnorm3_fwd"  [256]        256  BatchNorm(256, λ = relu)                              ["resnetv15_stage3_conv3_fwd"]                                          ["resnetv15_stage3_conv4_fwd"]
 "resnetv15_stage3_conv4_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_batchnorm3_fwd"]                                     ["resnetv15_stage3_batchnorm4_fwd"]
 "resnetv15_stage3_batchnorm4_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv4_fwd"]                                          ["resnetv15_stage3__plus1"]
 "resnetv15_stage3__plus1"          [256, 256]   256  #225                                                  ["resnetv15_stage3_activation0", "resnetv15_stage3_batchnorm4_fwd"]     ["resnetv15_stage3_activation1"]
 "resnetv15_stage3_activation1"     [256]        256  #195                                                  ["resnetv15_stage3__plus1"]                                             ["resnetv15_stage4_conv2_fwd", "resnetv15_stage4_conv0_fwd"]
 "resnetv15_stage4_conv2_fwd"       [256]        512  Conv((1, 1), 256=>512)                                ["resnetv15_stage3_activation1"]                                        ["resnetv15_stage4_batchnorm2_fwd"]
 "resnetv15_stage4_batchnorm2_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv2_fwd"]                                          ["resnetv15_stage4__plus0"]
 "resnetv15_stage4_conv0_fwd"       [256]        512  Conv((3, 3), 256=>512)                                ["resnetv15_stage3_activation1"]                                        ["resnetv15_stage4_batchnorm0_fwd"]
 "resnetv15_stage4_batchnorm0_fwd"  [512]        512  BatchNorm(512, λ = relu)                              ["resnetv15_stage4_conv0_fwd"]                                          ["resnetv15_stage4_conv1_fwd"]
 "resnetv15_stage4_conv1_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_batchnorm0_fwd"]                                     ["resnetv15_stage4_batchnorm1_fwd"]
 "resnetv15_stage4_batchnorm1_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv1_fwd"]                                          ["resnetv15_stage4__plus0"]
 "resnetv15_stage4__plus0"          [512, 512]   512  #225                                                  ["resnetv15_stage4_batchnorm2_fwd", "resnetv15_stage4_batchnorm1_fwd"]  ["resnetv15_stage4_activation0"]
 "resnetv15_stage4_activation0"     [512]        512  #195                                                  ["resnetv15_stage4__plus0"]                                             ["resnetv15_stage4_conv3_fwd", "resnetv15_stage4__plus1"]
 "resnetv15_stage4_conv3_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_activation0"]                                        ["resnetv15_stage4_batchnorm3_fwd"]
 "resnetv15_stage4_batchnorm3_fwd"  [512]        512  BatchNorm(512, λ = relu)                              ["resnetv15_stage4_conv3_fwd"]                                          ["resnetv15_stage4_conv4_fwd"]
 "resnetv15_stage4_conv4_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_batchnorm3_fwd"]                                     ["resnetv15_stage4_batchnorm4_fwd"]
 "resnetv15_stage4_batchnorm4_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv4_fwd"]                                          ["resnetv15_stage4__plus1"]
 "resnetv15_stage4__plus1"          [512, 512]   512  #225                                                  ["resnetv15_stage4_activation0", "resnetv15_stage4_batchnorm4_fwd"]     ["resnetv15_stage4_activation1"]
 "resnetv15_stage4_activation1"     [512]        512  #195                                                  ["resnetv15_stage4__plus1"]                                             ["resnetv15_pool1_fwd"]
 "resnetv15_pool1_fwd"              [512]        512  #127                                                  ["resnetv15_stage4_activation1"]                                        ["flatten_170"]
 "flatten_170"                      [512]        512  Flatten(-1)                                           ["resnetv15_pool1_fwd"]                                                 ["resnetv15_dense0_fwd"]
 "resnetv15_dense0_fwd"             [512]       1000  Dense(512, 1000)                                      ["flatten_170"]                                                         Any[]

All the methods above should have appropriate docstrings.
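
For readers who want to reuse the ad-hoc table, here is a minimal sketch of the same idea as a helper function. The name summarytable and the stand-in vertices are assumptions made for illustration and are not part of NaiveNASflux; with a real graph one would pass vertices(cg) together with accessors such as name, nin, nout and layer from the example above.

```julia
# Stand-in vertices (hypothetical); with a real graph these would come from vertices(cg).
vs = [
    (name = "in",     nin = Int[], nout = 10, inputs = String[]),
    (name = "dense1", nin = [10],  nout = 5,  inputs = ["in"]),
    (name = "dense2", nin = [5],   nout = 2,  inputs = ["dense1"]),
]

# Hypothetical helper: apply each accessor to every vertex and place the
# resulting columns side by side, giving one row per vertex.
summarytable(vs, accessors...) = hcat((map(f, vs) for f in accessors)...)

tbl = summarytable(vs, v -> v.name, v -> v.nin, v -> v.nout, v -> v.inputs)
display(tbl)
```

Passing the accessors as arguments keeps the column choice flexible, which matters when the columns of interest change from one session to the next.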

I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.

For examining the structure I usually just use netron on an exported model. It is also possible to extract the CompGraph as a LightGraph and use GraphPlot, but the results seldom look good, so I haven't bothered to advertise this possibility.

Hope this helps!

The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).

Send me a list of them and I'll add them when I find the time, or even better, try to implement them yourself using the documentation (filing issues about anything in it that does not make sense) and send me a PR :)

vtjeng commented 3 years ago

I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.

You could go for something like one of the two top answers here https://stackoverflow.com/questions/42480111/model-summary-in-pytorch. It doesn't handle networks with skip connections, but it does give a sense of the types of layers contained within. I think you actually have something quite close to this in your ad-hoc print?
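
As a rough sketch of what such a sequential summary could look like in Julia: the DenseSpec type, nparams and modelsummary below are hypothetical names used only for illustration (they are not Flux or NaiveNASflux API), and the two layers mirror the Dense(10, 5) and Dense(5, 2) sizes from the first post.

```julia
using Printf

# Hypothetical stand-in for a dense layer; not a Flux type.
struct DenseSpec
    nin::Int
    nout::Int
end

# Weights plus biases of a fully connected layer.
nparams(d::DenseSpec) = d.nin * d.nout + d.nout

# Print one line per layer followed by a total; skip connections are ignored.
function modelsummary(io::IO, layers)
    for (i, l) in enumerate(layers)
        @printf(io, "%-3d %-12s %4d => %-4d %8d\n",
                i, string(nameof(typeof(l))), l.nin, l.nout, nparams(l))
    end
    total = sum(nparams, layers)
    @printf(io, "Total params: %d\n", total)
    return total
end

total = modelsummary(stdout, [DenseSpec(10, 5), DenseSpec(5, 2)])
```

Like the linked answers, this only works for a plain sequence of layers, so it would not capture the branching seen in the resnet table.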

For examining the structure I usually just use netron on an exported model.

In my case I'm only importing models, so I want to visualize the model in Julia to make sure ONNXmutable 'got things right'. vertices will do the trick; I was just struggling to find this function. (I actually still can't figure out where it is defined in your source code in NaiveNASFlux / NaiveNASlib !)

Send me a list of them and I'll add them when I find the time, or even better, try to implement them yourself using the documentation (filing issues about anything in it that does not make sense) and send me a PR :)

I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.

DrChainsaw commented 3 years ago

I think you actually have something quite close to this in your ad-hoc print?

Yes, I was thinking of making something like that the show or summary method (or even a separate function), but I constantly change what I want to see, so I felt it would not be immediately useful for me. I guess it could be useful for people who did not write the library though...

vertices will do the trick; I was just struggling to find this function.

Ha, I realize now that it is only tangentially mentioned in the docs. I will advertise it much earlier as it is a pretty important function. Thinking back, this was because I initially thought I would build and advertise the CompGraph as a LightGraph, and vertices is a LightGraphs function.

I actually still can't figure out where it is defined in your source code in NaiveNASFlux / NaiveNASlib !

Here you go :)

The definition of flatten is right above.

I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.

Awesome!

DrChainsaw commented 3 years ago

make sure ONNXmutable 'got things right'.

Forgot to comment on this. Fwiw, there is a tool used in the unit tests that compares the output of a CompGraph with the output from onnxruntime. With a little bit of effort one can use it for a final verification that it really was the same model that materialized (assuming onnxruntime got it right ofc :) ).

Not the most fun thing when they don't produce the same output, but if they don't then that's an issue here that becomes my problem :)
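
To illustrate the kind of check involved, here is a minimal sketch of an output comparison. It is not the tool from the unit tests: outputs_match and the tolerance values are assumptions, and the arrays are stand-ins for what one would get by running the CompGraph and onnxruntime on the same input.

```julia
# Hypothetical helper: true when two model outputs have the same shape and
# agree element-wise within the given tolerances.
function outputs_match(actual::AbstractArray, expected::AbstractArray;
                       rtol = 1e-4, atol = 1e-6)
    size(actual) == size(expected) || return false
    return all(isapprox.(actual, expected; rtol = rtol, atol = atol))
end

# Stand-in data: `expected` plays the role of the reference runtime output,
# `actual` the output of the imported model with tiny floating point noise.
expected = Float32[0.1, 0.7, 0.2]
actual = expected .+ 1f-7

same = outputs_match(actual, expected)
different = outputs_match(expected .+ 0.1f0, expected)
```

A tolerance is needed because two runtimes rarely agree bit for bit; how tight it should be depends on the model and the element type.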

DrChainsaw commented 2 months ago

Better late than never I guess.

NaiveNASlib 2.1.0 uses PrettyTables to show a summary like the table above by default.

[Image: default show output]

The show method just uses the more customizable function graphsummary:

[Image: output with graphsummary]