Closed: vtjeng closed this issue 2 months ago
Hi and thanks for showing interest!
CompGraph has many methods for introspection, and most of them are documented in NaiveNASflux.
Here is an example of the kind of ad hoc summary table I usually use:
julia> cg = CompGraph("resnet18-v1-7.onnx");
# Return vertices of cg as a topologically sorted array
julia> vs = vertices(cg);
# Workaround for an oversight in NaiveNASflux which I will correct one day; without it, layer fails when it hits a non-Flux layer like the element-wise additions of the ResNet
julia> NaiveNASflux.layer(f) = f;
julia> [name.(vs) nin.(vs) nout.(vs) layer.(vs) map(ivs -> name.(ivs), inputs.(vs)) map(ovs -> name.(ovs), outputs.(vs))]
61×6 Matrix{Any}:
"data" Any[] 3 LayerTypeWrapper(FluxConv{2}()) Any[] ["resnetv15_conv0_fwd"]
"resnetv15_conv0_fwd" [3] 64 Conv((7, 7), 3=>64) ["data"] ["resnetv15_batchnorm0_fwd"]
"resnetv15_batchnorm0_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_conv0_fwd"] ["resnetv15_pool0_fwd"]
"resnetv15_pool0_fwd" [64] 64 MaxPool((3, 3), pad = (1, 1, 1, 1), stride = (2, 2)) ["resnetv15_batchnorm0_fwd"] ["resnetv15_stage1_conv0_fwd", "resnetv15_stage1__plus0"]
"resnetv15_stage1_conv0_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_pool0_fwd"] ["resnetv15_stage1_batchnorm0_fwd"]
"resnetv15_stage1_batchnorm0_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_stage1_conv0_fwd"] ["resnetv15_stage1_conv1_fwd"]
"resnetv15_stage1_conv1_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_batchnorm0_fwd"] ["resnetv15_stage1_batchnorm1_fwd"]
"resnetv15_stage1_batchnorm1_fwd" [64] 64 BatchNorm(64) ["resnetv15_stage1_conv1_fwd"] ["resnetv15_stage1__plus0"]
"resnetv15_stage1__plus0" [64, 64] 64 #225 ["resnetv15_pool0_fwd", "resnetv15_stage1_batchnorm1_fwd"] ["resnetv15_stage1_activation0"]
"resnetv15_stage1_activation0" [64] 64 #195 ["resnetv15_stage1__plus0"] ["resnetv15_stage1_conv2_fwd", "resnetv15_stage1__plus1"]
"resnetv15_stage1_conv2_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_activation0"] ["resnetv15_stage1_batchnorm2_fwd"]
"resnetv15_stage1_batchnorm2_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_stage1_conv2_fwd"] ["resnetv15_stage1_conv3_fwd"]
"resnetv15_stage1_conv3_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_batchnorm2_fwd"] ["resnetv15_stage1_batchnorm3_fwd"]
"resnetv15_stage1_batchnorm3_fwd" [64] 64 BatchNorm(64) ["resnetv15_stage1_conv3_fwd"] ["resnetv15_stage1__plus1"]
"resnetv15_stage1__plus1" [64, 64] 64 #225 ["resnetv15_stage1_activation0", "resnetv15_stage1_batchnorm3_fwd"] ["resnetv15_stage1_activation1"]
"resnetv15_stage1_activation1" [64] 64 #195 ["resnetv15_stage1__plus1"] ["resnetv15_stage2_conv2_fwd", "resnetv15_stage2_conv0_fwd"]
"resnetv15_stage2_conv2_fwd" [64] 128 Conv((1, 1), 64=>128) ["resnetv15_stage1_activation1"] ["resnetv15_stage2_batchnorm2_fwd"]
"resnetv15_stage2_batchnorm2_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv2_fwd"] ["resnetv15_stage2__plus0"]
"resnetv15_stage2_conv0_fwd" [64] 128 Conv((3, 3), 64=>128) ["resnetv15_stage1_activation1"] ["resnetv15_stage2_batchnorm0_fwd"]
"resnetv15_stage2_batchnorm0_fwd" [128] 128 BatchNorm(128, λ = relu) ["resnetv15_stage2_conv0_fwd"] ["resnetv15_stage2_conv1_fwd"]
"resnetv15_stage2_conv1_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_batchnorm0_fwd"] ["resnetv15_stage2_batchnorm1_fwd"]
"resnetv15_stage2_batchnorm1_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv1_fwd"] ["resnetv15_stage2__plus0"]
"resnetv15_stage2__plus0" [128, 128] 128 #225 ["resnetv15_stage2_batchnorm2_fwd", "resnetv15_stage2_batchnorm1_fwd"] ["resnetv15_stage2_activation0"]
"resnetv15_stage2_activation0" [128] 128 #195 ["resnetv15_stage2__plus0"] ["resnetv15_stage2_conv3_fwd", "resnetv15_stage2__plus1"]
"resnetv15_stage2_conv3_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_activation0"] ["resnetv15_stage2_batchnorm3_fwd"]
"resnetv15_stage2_batchnorm3_fwd" [128] 128 BatchNorm(128, λ = relu) ["resnetv15_stage2_conv3_fwd"] ["resnetv15_stage2_conv4_fwd"]
"resnetv15_stage2_conv4_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_batchnorm3_fwd"] ["resnetv15_stage2_batchnorm4_fwd"]
"resnetv15_stage2_batchnorm4_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv4_fwd"] ["resnetv15_stage2__plus1"]
"resnetv15_stage2__plus1" [128, 128] 128 #225 ["resnetv15_stage2_activation0", "resnetv15_stage2_batchnorm4_fwd"] ["resnetv15_stage2_activation1"]
"resnetv15_stage2_activation1" [128] 128 #195 ["resnetv15_stage2__plus1"] ["resnetv15_stage3_conv2_fwd", "resnetv15_stage3_conv0_fwd"]
"resnetv15_stage3_conv2_fwd" [128] 256 Conv((1, 1), 128=>256) ["resnetv15_stage2_activation1"] ["resnetv15_stage3_batchnorm2_fwd"]
"resnetv15_stage3_batchnorm2_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv2_fwd"] ["resnetv15_stage3__plus0"]
"resnetv15_stage3_conv0_fwd" [128] 256 Conv((3, 3), 128=>256) ["resnetv15_stage2_activation1"] ["resnetv15_stage3_batchnorm0_fwd"]
"resnetv15_stage3_batchnorm0_fwd" [256] 256 BatchNorm(256, λ = relu) ["resnetv15_stage3_conv0_fwd"] ["resnetv15_stage3_conv1_fwd"]
"resnetv15_stage3_conv1_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_batchnorm0_fwd"] ["resnetv15_stage3_batchnorm1_fwd"]
"resnetv15_stage3_batchnorm1_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv1_fwd"] ["resnetv15_stage3__plus0"]
"resnetv15_stage3__plus0" [256, 256] 256 #225 ["resnetv15_stage3_batchnorm2_fwd", "resnetv15_stage3_batchnorm1_fwd"] ["resnetv15_stage3_activation0"]
"resnetv15_stage3_activation0" [256] 256 #195 ["resnetv15_stage3__plus0"] ["resnetv15_stage3_conv3_fwd", "resnetv15_stage3__plus1"]
"resnetv15_stage3_conv3_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_activation0"] ["resnetv15_stage3_batchnorm3_fwd"]
"resnetv15_stage3_batchnorm3_fwd" [256] 256 BatchNorm(256, λ = relu) ["resnetv15_stage3_conv3_fwd"] ["resnetv15_stage3_conv4_fwd"]
"resnetv15_stage3_conv4_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_batchnorm3_fwd"] ["resnetv15_stage3_batchnorm4_fwd"]
"resnetv15_stage3_batchnorm4_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv4_fwd"] ["resnetv15_stage3__plus1"]
"resnetv15_stage3__plus1" [256, 256] 256 #225 ["resnetv15_stage3_activation0", "resnetv15_stage3_batchnorm4_fwd"] ["resnetv15_stage3_activation1"]
"resnetv15_stage3_activation1" [256] 256 #195 ["resnetv15_stage3__plus1"] ["resnetv15_stage4_conv2_fwd", "resnetv15_stage4_conv0_fwd"]
"resnetv15_stage4_conv2_fwd" [256] 512 Conv((1, 1), 256=>512) ["resnetv15_stage3_activation1"] ["resnetv15_stage4_batchnorm2_fwd"]
"resnetv15_stage4_batchnorm2_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv2_fwd"] ["resnetv15_stage4__plus0"]
"resnetv15_stage4_conv0_fwd" [256] 512 Conv((3, 3), 256=>512) ["resnetv15_stage3_activation1"] ["resnetv15_stage4_batchnorm0_fwd"]
"resnetv15_stage4_batchnorm0_fwd" [512] 512 BatchNorm(512, λ = relu) ["resnetv15_stage4_conv0_fwd"] ["resnetv15_stage4_conv1_fwd"]
"resnetv15_stage4_conv1_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_batchnorm0_fwd"] ["resnetv15_stage4_batchnorm1_fwd"]
"resnetv15_stage4_batchnorm1_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv1_fwd"] ["resnetv15_stage4__plus0"]
"resnetv15_stage4__plus0" [512, 512] 512 #225 ["resnetv15_stage4_batchnorm2_fwd", "resnetv15_stage4_batchnorm1_fwd"] ["resnetv15_stage4_activation0"]
"resnetv15_stage4_activation0" [512] 512 #195 ["resnetv15_stage4__plus0"] ["resnetv15_stage4_conv3_fwd", "resnetv15_stage4__plus1"]
"resnetv15_stage4_conv3_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_activation0"] ["resnetv15_stage4_batchnorm3_fwd"]
"resnetv15_stage4_batchnorm3_fwd" [512] 512 BatchNorm(512, λ = relu) ["resnetv15_stage4_conv3_fwd"] ["resnetv15_stage4_conv4_fwd"]
"resnetv15_stage4_conv4_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_batchnorm3_fwd"] ["resnetv15_stage4_batchnorm4_fwd"]
"resnetv15_stage4_batchnorm4_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv4_fwd"] ["resnetv15_stage4__plus1"]
"resnetv15_stage4__plus1" [512, 512] 512 #225 ["resnetv15_stage4_activation0", "resnetv15_stage4_batchnorm4_fwd"] ["resnetv15_stage4_activation1"]
"resnetv15_stage4_activation1" [512] 512 #195 ["resnetv15_stage4__plus1"] ["resnetv15_pool1_fwd"]
"resnetv15_pool1_fwd" [512] 512 #127 ["resnetv15_stage4_activation1"] ["flatten_170"]
"flatten_170" [512] 512 Flatten(-1) ["resnetv15_pool1_fwd"] ["resnetv15_dense0_fwd"]
"resnetv15_dense0_fwd" [512] 1000 Dense(512, 1000) ["flatten_170"] Any[]
All the methods above should have appropriate docstrings.
I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.
For examining the structure I usually just use Netron on an exported model. It is also possible to extract the CompGraph as a LightGraph and use GraphPlot, but the results very seldom look good, so I haven't bothered to advertise this possibility.
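A minimal sketch of that route, built only from the introspection functions shown earlier in this thread (vertices, inputs, name); whether a nicer dedicated conversion function exists is not stated here:

```julia
using LightGraphs, GraphPlot

# Build a directed graph whose edges follow the input relations of the
# CompGraph, then plot it with the vertex names as labels.
vs = vertices(cg)                         # topologically sorted, as above
idx = Dict(v => i for (i, v) in enumerate(vs))
g = SimpleDiGraph(length(vs))
for v in vs, iv in inputs(v)
    add_edge!(g, idx[iv], idx[v])
end
gplot(g, nodelabel = name.(vs))
```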
Hope this helps!
The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).
Send me a list of them and I'll add them when I find the time, or even better, try to implement them yourself using the documentation (filing issues about anything in it that does not make sense) and send me a PR :)
I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.
You could go for something like one of the two top answers here: https://stackoverflow.com/questions/42480111/model-summary-in-pytorch. It doesn't handle networks with skip connections, but it does give a sense of the types of layers contained within. I think you actually have something quite close to this in your ad-hoc print?
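For what it's worth, such a per-layer summary can be sketched using only the introspection functions shown above (layersummary is a hypothetical helper name, not part of any package):

```julia
using Printf

# Hypothetical helper: print one row per vertex, PyTorch-summary style,
# using the documented introspection functions name/nin/nout.
function layersummary(cg)
    for v in vertices(cg)
        @printf("%-40s nin=%-12s nout=%s\n", name(v), string(nin(v)), string(nout(v)))
    end
end

layersummary(cg)  # cg as loaded earlier in the thread
```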
For examining the structure I usually just use netron on an exported model.
In my case I'm only importing models, so I want to visualize the model in Julia to make sure ONNXmutable 'got things right'. vertices will do the trick; I was just struggling to find this function. (I actually still can't figure out where it is defined in your source code in NaiveNASflux / NaiveNASlib!)
Send me a list of them and I'll add them when I find the time, or even better, try to make implement them yourself using the documentation (filing issues about what does not make sense in it) and send me a PR :)
I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.
I think you actually have something quite close to this in your ad-hoc print?
Yes, I was thinking of making something like that the show or summary method (or even its own function), but I constantly change what I want to see, so I felt it would not be immediately useful for me. I guess it could be for people who did not write the library, though...
vertices will do the trick; I was just struggling to find this function.
Ha, I realize now that it is only tangentially mentioned in the docs. I will advertise it much earlier, as it is a pretty important function. When I think back, it was because I initially thought I would build and advertise the CompGraph as a LightGraph, and vertices is a LightGraphs function.
I actually still can't figure out where it is defined in your source code in NaiveNASflux / NaiveNASlib!
The definition of flatten is right above.
I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.
Awesome!
make sure ONNXmutable 'got things right'.
Forgot to comment on this. FWIW, there is a tool to compare the output of a CompGraph with the output from onnxruntime that is used in the unit tests. With a little bit of effort, one can use it to make a final verification that it really was the same model that materialized (assuming onnxruntime got it right, of course :)).
It's not the most fun thing when they don't produce the same output, but if they don't, that's an issue here that becomes my problem :)
Better late than never I guess.
NaiveNASlib 2.1.0 uses PrettyTables to show a summary like the table above by default. The show method just uses the more customizable function graphsummary.
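A hedged sketch of calling that directly; the exact signature and customization options of graphsummary are not shown in this thread, so check its docstring before relying on this:

```julia
using NaiveNASlib

# cg loaded via ONNXmutable as earlier in the thread.
cg = CompGraph("resnet18-v1-7.onnx")

# Prints the same PrettyTables summary that show produces by default;
# see the graphsummary docstring for extra-column options.
graphsummary(cg)
```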
I was hoping to use ONNXmutable.jl to extract the Flux operations specified in an .onnx file. Specifically, I'd like to be able to provide an .onnx file and load in a function corresponding to a composition of Flux primitives - something like

What I've tried

- CompGraph (or an ONNX.ModelProto() using extract).
- CompGraph to understand what Flux operations are contained in the CompGraph.

Use Case

- JuMP variables (or regular floats), producing the appropriate output (appropriately constrained JuMP variables if JuMP variables are passed in, and regular floats corresponding to forward propagation otherwise).
- .onnx files (with a subset of supported operations)
- Flux layers [1] to replace my custom layers, and implement functionality that enables these Flux layers to accept JuMP variables and produce the appropriate output.

[1] I'm using Flux because it seems to be the best supported framework related to neural networks in Julia, but would be open to suggestions to consider other options.

Additional Notes

The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).
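For reference, an ONNX Gemm followed by Relu corresponds to an ordinary Dense layer in Flux, so the simple feedforward case described above can be sketched as follows (the layer sizes are made-up placeholders, not taken from any real model):

```julia
using Flux

# ONNX Gemm ≈ W*x .+ b, i.e. a Flux Dense layer; Relu ≈ the relu activation.
model = Chain(Dense(784, 128, relu), Dense(128, 10))

x = rand(Float32, 784)
y = model(x)   # plain forward propagation with regular floats
size(y)        # (10,)
```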