DrChainsaw / ONNXNaiveNASflux.jl

Import/export ONNX models
MIT License
44 stars, 2 forks

Error exporting a model with Flux.GlobalMeanPool() operation #87

Closed: lambe closed this issue 3 months ago

lambe commented 1 year ago

Here's a simple script to test exporting and importing a model:

using Flux
using ONNXNaiveNASflux

test_model = Chain(
    Conv((1, 1), 10 => 10),
    GlobalMeanPool(),
    Flux.MLUtils.flatten,
    Dense(10, 10, relu),
    Dense(10, 10, relu),
)

test_tensor = rand(Float32, 5, 5, 10, 4)
out_tensor = test_model(test_tensor)
println("out_tensor size: ", size(out_tensor))  # Expected output size (10, 4)

# Export the model
ONNXNaiveNASflux.save("test.onnx", test_model)

# Import the model
new_model = ONNXNaiveNASflux.load("test.onnx")
@assert new_model == test_model

Error returned:

ERROR: LoadError: MethodError: no method matching size(::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})

Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:582
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:585
  ...

Stacktrace:
  [1] (::GlobalMeanPool)(x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/conv.jl:631
  [2] macro expansion
    @ ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53 [inlined]
  [3] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}, x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53
  [4] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})(x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:51
  [5] graphproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, indata::Pair{String, Tuple{Missing, Missing, Int64, Missing}}; namestrat::Function)
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:200
  [6] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, indata::Pair{String, Tuple{Missing, Missing, Int64, Missing}}; modelname::String, namestrat::Function, posthook::typeof(ONNXNaiveNASflux.validate), kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:45
  [7] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, inshapes::Tuple{Missing, Missing, Int64, Missing}; kwargs::Base.Pairs{Symbol, String, Tuple{Symbol}, NamedTuple{(:modelname,), Tuple{String}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:42
  [8] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}; kwargs::Base.Pairs{Symbol, String, Tuple{Symbol}, NamedTuple{(:modelname,), Tuple{String}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:41
  [9] modelproto
    @ ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:41 [inlined]
 [10] #save#308
    @ ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:10 [inlined]
 [11] save(::String, ::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:10
 [12] top-level scope
    @ ~/toolpath/AlphaToolpath/examples/globalmeanpool_test.jl:19
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [14] top-level scope
    @ REPL[1]:1
 [15] top-level scope
    @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
in expression starting at /home/ablambe/toolpath/AlphaToolpath/examples/globalmeanpool_test.jl:19

GlobalAveragePool is listed as a supported operation in the docs, but it doesn't seem to work for me. Is it a name change issue?

DrChainsaw commented 1 year ago

I had to check the code myself, but it seems that serialization of Flux's built-in global pools is not implemented. IIRC, both the deserialization implementation and the statement that it is supported date from before Flux had its own global pools, when Flux just told users to build their own (e.g. from MaxPool/MeanPool).

It should be pretty easy to add, though. Just add the methods somewhere here and call the functions below directly. Be mindful of any difference between what Flux does and what ONNX says the global pools do (described somewhere here), for example with respect to dropping the spatial dimensions.

The second argument to globalmeanpool and globalmaxpool can be used to account for this difference (identity should work if there is none).
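For reference, the Flux docs describe GlobalMeanPool as turning a (w, h, c, b) input into a (1, 1, c, b) output, and the ONNX spec says GlobalAveragePool keeps the input rank with every spatial dimension set to 1. So the two appear to agree on keeping singleton spatial dims. The sketch below uses only the Statistics stdlib (not Flux or ONNXNaiveNASflux) to contrast the keep-dims and drop-dims conventions; both helper names are hypothetical.

```julia
using Statistics

# Hypothetical helper: global mean pooling over the spatial dims of a
# WHCN array, keeping them as singleton dims (the Flux/ONNX convention).
global_mean_keepdims(x::AbstractArray{<:Real,4}) = mean(x; dims=(1, 2))

# Hypothetical helper: the same pooling with the singleton spatial dims
# dropped, giving a (C, N) array like the one flatten produces in the test model.
global_mean_dropdims(x::AbstractArray{<:Real,4}) =
    dropdims(mean(x; dims=(1, 2)); dims=(1, 2))

x = rand(Float32, 5, 5, 10, 4)   # W, H, C, N as in the test script

kept = global_mean_keepdims(x)
dropped = global_mean_dropdims(x)

println(size(kept))     # (1, 1, 10, 4)
println(size(dropped))  # (10, 4)

# Both contain the same numbers; only the shape differs
@assert vec(kept) == vec(dropped)
```

The dropdims step is the kind of adjustment the second argument mentioned above would be for. When the conventions already match, identity is enough.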

lambe commented 1 year ago

OK, I've added the appropriate methods, and the script in the description can now save and load the model. See #88.

A warning is still emitted, though:

┌ Warning: No valid input sizes provided. Shape inference could not be done. Either provide Integer insizes manually or use load(...; infer_shapes=false) to disable. If disabled, graph mutation might not work.
└ @ ONNXNaiveNASflux ~/toolpath/ONNXNaiveNASflux.jl/src/deserialize/infershape.jl:47

and the @assert statement fails, so I'm going to check whether deserialize/ops.jl needs updating.

lambe commented 1 year ago

Update: the @assert statement is expected to fail, since test_model is a Flux Chain whereas new_model is an ONNX-style computational graph.

However, there is some good news: running test_tensor through the saved model (here via ONNXRunTime) gives approximately the same answer. (Add the following code block to the test script in the description.)

import ONNXRunTime as ORT

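# Run an ONNXRunTime session on an array in Flux's WHCN layout
function run_ort(nn_ort, x)
    x_t = permutedims(x, (4, 3, 2, 1))  # WHCN (Flux) -> NCHW (ONNX)
    inp = Dict(only(nn_ort.input_names) => x_t)
    out = only(values(nn_ort(inp)))
    permutedims(out, reverse(1:ndims(out)))  # back to Flux's dimension order
end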

ort_new = ORT.load_inference("test.onnx")
new_out_tensor = run_ort(ort_new, test_tensor)
@assert isapprox(new_out_tensor, out_tensor)
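The permutedims calls in run_ort are needed because Flux orders image dimensions as W×H×C×N, while ONNX describes them as N×C×H×W. Reversing the dimension order maps one onto the other. Here is a minimal stdlib-only check of that round trip, which assumes nothing about ONNXRunTime itself (the 5×6 spatial size is chosen only to make W and H distinguishable):

```julia
# WHCN array as used by Flux (width, height, channels, batch)
x = rand(Float32, 5, 6, 10, 4)

# Reverse the dimension order to get the NCHW shape ONNX expects
x_nchw = permutedims(x, (4, 3, 2, 1))
println(size(x_nchw))  # (4, 10, 6, 5)

# Each element moves to the mirrored index position
@assert x_nchw[2, 3, 4, 5] == x[5, 4, 3, 2]

# Reversing again restores the original array, which is what run_ort
# does with the session output
x_back = permutedims(x_nchw, reverse(1:ndims(x_nchw)))
@assert x_back == x
```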

I'm happy with this solution, but interested to know if that warning can be addressed before merging the PR.

DrChainsaw commented 1 year ago

I took a look at the warning; it is correct to warn here, so no action is needed.

The reason for the warning is this:

  1. When saving the model, the height and width of the input can't be inferred from the type of the first layer alone (e.g. Conv((1,1), 10 => 10) does not require the first two dimensions to have any particular size). The number of channels is correctly inferred, though.
  2. When loading, ONNXNaiveNASflux uses Flux.outputsize to infer all input sizes if they are not given. This function throws an exception if the size of any dimension is missing or 0. It also throws if sizes don't line up (e.g. if someone does flatten or reshape without global pooling first), so we can't just guess a size. Instead, the loader checks that all sizes are > 0; if not, the attempt at shape inference is abandoned and the warning is printed.
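The check described in point 2 can be sketched as follows. This is not ONNXNaiveNASflux's actual code: can_infer_shapes is a hypothetical helper, and the (missing, missing, 10, missing) tuple mirrors the input shape recorded at save time, visible in the stack trace above as Tuple{Missing, Missing, Int64, Missing}.

```julia
# Hypothetical helper, not ONNXNaiveNASflux API: shape inference is only
# attempted when every dimension is a known, positive integer.
can_infer_shapes(insize) = all(s -> s isa Integer && s > 0, insize)

# What saving the test model records: only the channel count is known
saved_shape = (missing, missing, 10, missing)

# What passing size(test_tensor) supplies instead
full_shape = (5, 5, 10, 4)

for shape in (saved_shape, full_shape, (5, 5, 0, 4))
    if can_infer_shapes(shape)
        println(shape, " => run shape inference")
    else
        # Mirrors the behavior described above: give up and warn
        @warn "No valid input sizes provided for $shape; skipping shape inference"
    end
end
```

Supplying size(test_tensor) when loading or saving, as described next, corresponds to the fully known case.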

To avoid the warning, you can just supply the input sizes when loading the model:

ONNXNaiveNASflux.load("test.onnx", size(test_tensor))

or when saving:

ONNXNaiveNASflux.save("test.onnx", test_model, size(test_tensor))

The input sizes are only needed when using NaiveNASlib features such as parameter pruning or other NAS-like modifications, which is why the model works fine for inference (and should work fine for training too).

Only a few op types need the size info, so chances are that the NaiveNASlib functionality would work even without inferred input sizes. However, the error you get when trying to change the dimension of some parameter without this info would be quite hard to understand, so I decided to always print the warning by default.

lambe commented 1 year ago

Great, thanks for the context! I'll update my calling code with a size parameter.

DrChainsaw commented 3 months ago

Fixed in #88