Closed: lambe closed this 3 months ago
I had to check the code myself, but it seems that serialization of Flux's built-in global pools is not implemented. IIRC the deserialization implementation, as well as the statement that it is supported, dates from before Flux had its own global pools, when users were just told to build their own (e.g. from Max/MeanPool).
It should be pretty easy to add though. Just add the methods somewhere here and directly call the functions below. Just be mindful of whether there is a difference between what Flux does and what ONNX says the global pools do (described somewhere here), for example w.r.t. dropping the spatial dimensions.
The second argument to the `globalmeanpool` and `globalmaxpool` functions can be used to account for this difference (`identity` should work if there is no difference).
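For reference, Flux's own global pools keep the spatial dimensions as singletons rather than dropping them, and the ONNX spec for GlobalAveragePool likewise specifies output spatial dimensions of 1, so `identity` may well be enough. A minimal check, assuming Flux's WHCN data layout:

```julia
using Flux

x = rand(Float32, 5, 5, 3, 2)   # WHCN: 5×5 spatial, 3 channels, batch of 2
y = GlobalMeanPool()(x)

size(y)                         # (1, 1, 3, 2): spatial dims kept as size 1
```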
OK, I've added the appropriate methods, and the script in the description is able to save and load the model. See #88
There's still a warning thrown:

```
┌ Warning: No valid input sizes provided. Shape inference could not be done. Either provide Integer insizes manually or use load(...; infer_shapes=false) to disable. If disabled, graph mutation might not work.
└ @ ONNXNaiveNASflux ~/toolpath/ONNXNaiveNASflux.jl/src/deserialize/infershape.jl:47
```
and the `@assert` statement fails, so I'm going to see if updating `deserialize/ops.jl` is needed.
Update: the `@assert` statement is expected to fail, since `test_model` is a Flux model type but `new_model` is an ONNX-style computational graph. However, some good news: running `test_tensor` through `new_model` results in an approximately identical answer. (Add the following code block to the test script in the description.)
```julia
import ONNXRunTime as ORT

# ONNXRunTime expects ONNX's NCHW layout while Flux uses WHCN,
# so permute the dimensions on the way in and back out.
function run_ort(nn_ort, x)
    x_t = permutedims(x, (4, 3, 2, 1))
    inp = Dict(only(nn_ort.input_names) => x_t)
    out = only(values(nn_ort(inp)))
    permutedims(out, reverse(1:ndims(out)))
end

ort_new = ORT.load_inference("test.onnx")
new_out_tensor = run_ort(ort_new, test_tensor)
@assert isapprox(new_out_tensor, out_tensor)
```
I'm happy with this solution, but interested to know if that warning can be addressed before merging the PR.
Took a look at the warning, and it is correct to warn here, so no action is needed.
The reason for the warning is this: the model itself does not pin down the input sizes (e.g. `Conv((1,1), 10 => 10)` does not need the first two dimensions to have any particular size), so they are not stored when saving the model. The number of channels is correctly inferred though. `Flux.outputshape` is used to infer all input sizes when they are not given as input when loading. This function throws an exception if the size of any dimension is missing or 0. It will also throw if sizes don't line up (e.g. if someone does flatten or reshape without global pooling first), so we can't just guess some size. Instead it is checked that all sizes are > 0, and if not, the attempt at shape inference is abandoned and the warning is printed.

To avoid the warning, you can just supply the input sizes when loading the model:
```julia
ONNXNaiveNASflux.load("test.onnx", size(test_tensor))
```

or when saving:

```julia
ONNXNaiveNASflux.save("test.onnx", test_model, size(test_tensor))
```
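As a quick illustration of why the spatial sizes cannot be recovered from the saved model alone: a 1×1 Conv layer like the one above accepts inputs of any spatial size, so nothing in its parameters fixes the input shape.

```julia
using Flux

c = Conv((1, 1), 10 => 10)              # no particular spatial size required
size(c(rand(Float32, 5, 5, 10, 2)))     # (5, 5, 10, 2)
size(c(rand(Float32, 7, 3, 10, 2)))     # (7, 3, 10, 2)
```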
The input sizes are only used by the NaiveNASlib features for parameter pruning and other NAS-like things, which is why the model seems to work just fine for inference (and should work fine for training too).
There are only a few op types which need the size info, so chances are the NaiveNASlib functionality would work even if input sizes are not inferred. However, since the error you would get when trying to change the dimension of some parameter with this info missing would be quite difficult to understand, I decided to always print the warning by default.
Great, thanks for the context! I'll update my calling code with a size parameter.
Fixed in #88
Here's a simple script to test exporting and importing a model
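(The script itself is missing from this thread; below is a hypothetical reconstruction. The layer choice and the names `test_model`, `test_tensor` and `out_tensor` are taken from the comments in this thread, so details may differ from the original.)

```julia
# Hypothetical reconstruction of the elided test script.
using Flux
import ONNXNaiveNASflux

test_model = Chain(Conv((1, 1), 10 => 10), GlobalMeanPool())
test_tensor = rand(Float32, 5, 5, 10, 2)   # WHCN layout
out_tensor = test_model(test_tensor)

ONNXNaiveNASflux.save("test.onnx", test_model)
new_model = ONNXNaiveNASflux.load("test.onnx")
# Fails: the loaded model is a computational graph, not a Flux Chain
@assert new_model == test_model
```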
Error returned:
GlobalAveragePool is listed as a supported operation in the docs, but it doesn't seem to work for me. Is it a name change issue?