DrChainsaw / ONNXNaiveNASflux.jl

Import/export ONNX models
MIT License

input shapes for combined chains #75

Closed: jbernalr closed this issue 1 year ago.

jbernalr commented 1 year ago

Is it possible to save an ONNX model like this one, which is put together in a Flux.Chain object? This currently does not work because I have failed to provide matching input shapes. That brings me to my second question: what are the input shapes of a combined chain like this?

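using Flux
using ONNXNaiveNASflux: save  # `save` is called unqualified below

chain = Flux.Chain(Dense(1, 6, σ), Dense(6, 1))
# ...
save("my model.onnx", chain)  # Julia strings need double quotes; 'x' is a Char literal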
DrChainsaw commented 1 year ago

Hi,

It looks like the main problem is that Sigmoid is not supported by ONNXNaiveNASflux yet. Fortunately it is trivial to add, and I will add it in #76.

In the meantime, you can add support for it yourself by just running this:

using Flux, ONNXNaiveNASflux
Flux.σ(pp::ONNXNaiveNASflux.AbstractProbe) = ONNXNaiveNASflux.attribfun(identity, "Sigmoid", pp)  # export σ as an ONNX "Sigmoid" op

As for input shapes, they should be inferred correctly from Chain in most cases.

If you want to supply them yourself, they should be given in the shape Flux expects them, for example save("mymodel.onnx", chain(1, :BatchSize)) or save("mymodel.onnx", chain, (1, missing)).

jbernalr commented 1 year ago

@DrChainsaw thank you! That was a quick fix. I am able to export the model using both the inferred input shapes and missing:

save("my model.onnx", chain)
save("mymodel.onnx", chain, (1, missing))

however this fails:

save("mymodel.onnx", chain(1, :BatchSize))

error:

ERROR: LoadError: MethodError: no method matching (::Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})(::Int64, ::Symbol)
Closest candidates are:
  (::Chain)(::Any) at ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:5

For the successful cases, I tried to import the model in MATLAB and noticed that the model is exported with size '(?, 1)', which causes an error in MATLAB. My question here is: is it right that the model is exported with an undetermined size?

Just for completeness, this is the error thrown, which I think I can fix by being more specific when calling the import function:

Unable to import the network because of the following issues:

1 operator(s)   :   Unable to create an input layer for ONNX input #1 (with name 'data_0')
because its data format is unknown or not supported as a MATLAB input layer. If you know the
input format, pass it by using the "InputDataFormats" parameter.
 The input shape declared in the ONNX file is '(?, 1)'.
1 operator(s)   :   Unable to create an output layer for ONNX network output #1 (with name
'dense_1') because its data format is unknown or not supported as a MATLAB output layer. If
you know the output format, pass it using the 'OutputDataFormats' parameter.
DrChainsaw commented 1 year ago

Great that it was just the Sigmoid. Most ops are trivial to add, but I'm only adding them when someone needs them, to reduce the future maintenance burden if/when I need to refactor things.

I made a copy-paste error in the input-size example that fails; it should be save("mymodel.onnx", chain, (1, :BatchSize)) (note the comma after chain).

About the matlab error and the (?, 1) size:

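It should be correct. The question mark is the batch size, which can't be inferred from the model itself and which is generally not a fixed number for a given model except in rare circumstances. It comes first because ONNX lists dimensions in the reverse order of Flux, so the (1, missing) you passed shows up as (?, 1).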

IIRC, ONNX supports three different ways to specify an input size along any dimension:
1) As a number, stating exactly what it is, like the 1 in your model.
2) As a variable with a name, which is what :BatchSize will result in.
3) Unspecified, which is what missing will do (this is also the default for sizes which can't be inferred, like the batch size).

ONNX is a bit of a sprawling spec, and implementations with only partial coverage of it seem to be the norm. Here it appears that MATLAB does not support option 3 (which it seems to print as ?).
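To make the three dimension kinds and the order reversal concrete, here is a small pure-Julia illustration. The helpers `onnx_dim`, `onnx_shape` and `onnx_shape_string` are hypothetical and are not part of ONNXNaiveNASflux; they only mimic how a Flux-style size tuple maps onto ONNX dimensions, assuming batch-last order in Flux and batch-first order in ONNX.

```julia
# Hypothetical illustration, NOT the ONNXNaiveNASflux API: maps each kind of
# Flux-style size entry to one of the three ONNX dimension kinds.

onnx_dim(d::Integer) = (kind = :dim_value, value = d)      # 1) fixed number
onnx_dim(d::Symbol)  = (kind = :dim_param, value = d)      # 2) named variable
onnx_dim(::Missing)  = (kind = :unknown, value = nothing)  # 3) left unspecified

# Flux orders dimensions batch-last while ONNX lists them batch-first,
# so the tuple is reversed before each entry is converted.
onnx_shape(fluxshape::Tuple) = map(onnx_dim, reverse(fluxshape))

# "?" for unknown dims, as in the MATLAB message; showing names and numbers
# as plain text is our own choice here, not MATLAB's.
showdim(d) = d.kind === :unknown ? "?" : string(d.value)
onnx_shape_string(fluxshape::Tuple) = "(" * join(map(showdim, onnx_shape(fluxshape)), ", ") * ")"

println(onnx_shape_string((1, missing)))     # prints (?, 1)
println(onnx_shape_string((1, :BatchSize)))  # prints (BatchSize, 1)
println(onnx_shape_string((1, 32)))          # prints (32, 1)
```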

You can either try option 2 (i.e. the one with :BatchSize) or just put in a fixed number for the batch size (maybe you have one you intend to use).
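A minimal sketch of those two alternatives, assuming an ONNXNaiveNASflux release that already includes the Sigmoid support from #76 (on older versions, run the Flux.σ workaround from earlier in the thread first). The file names and the batch size of 32 are placeholders, not values from the thread.

```julia
using Flux
using ONNXNaiveNASflux: save

chain = Chain(Dense(1, 6, σ), Dense(6, 1))

# Option 2: the batch dimension becomes a named variable in the ONNX file
save("mymodel_batchsize.onnx", chain, (1, :BatchSize))

# Fixed batch size: every input dimension is a plain number (32 is a placeholder)
save("mymodel_fixed.onnx", chain, (1, 32))
```

Whether MATLAB accepts the named dimension is not confirmed in this thread; the fixed-size file is the one with no symbolic or unknown dimensions left.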

As you said, using InputDataFormats is probably a perfectly fine way to get MATLAB to import the model; it is probably there for exactly this reason.

FWIW, ONNXNaiveNASflux also lets the user override any model shapes in the load function; that option exists for exactly the case where someone tries to load a model with an input specification it does not yet support.

DrChainsaw commented 1 year ago

Sorry, I got a bit distracted, but I have now merged the Sigmoid support and will release it soon. Please reopen if you still have issues loading the model.