Xilinx / finn-base

Open Source Compiler Framework using ONNX as Frontend and IR
https://finn-base.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

fix_float64 not passed if cleanup is run #63

Open jmitrevs opened 2 years ago

jmitrevs commented 2 years ago

In the transform member function of ModelWrapper:

https://github.com/Xilinx/finn-base/blob/2f10c0c1f11cc483f27794887c41030bdd0dfab5/src/finn/core/modelwrapper.py#L122-L146

If cleanup=True, the cleanup transformations are applied with fix_float64=True, no matter what value of fix_float64 transform was originally called with. I think either fix_float64 should default to False, or the caller's fix_float64 value should be passed through to the cleanup.
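A minimal Python sketch of the second option (propagating the flag), using simplified stand-in classes rather than the real finn-base API:

```python
# Hypothetical sketch of the proposed fix to ModelWrapper.transform.
# Names and structure are simplified stand-ins; the real method lives in
# src/finn/core/modelwrapper.py. The key change: forward the caller's
# fix_float64 to the cleanup pass instead of hard-coding fix_float64=True.

class ModelWrapperSketch:
    def __init__(self, model):
        self.model = model

    def transform(self, transformation, cleanup=True, fix_float64=False):
        # apply the requested transformation (stubbed here)
        transformed = transformation(self.model)
        if cleanup:
            # BEFORE (buggy): self._cleanup(transformed, fix_float64=True)
            # AFTER (proposed): propagate the caller's flag
            transformed = self._cleanup(transformed, fix_float64=fix_float64)
        return transformed

    def _cleanup(self, model, fix_float64):
        # stand-in for the real cleanup transformations; just records
        # which flag value reached the cleanup step
        model["fix_float64_used"] = fix_float64
        return model
```

With this sketch, transform(..., fix_float64=False) no longer silently re-enables the float64 workaround during cleanup.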

jmitrevs commented 2 years ago

A simple onnx file that shows this behavior is here:

https://drive.google.com/file/d/16Opd4K3xjEEp9jgmmkZgqW5fXCFJ84lA/view?usp=sharing

(github doesn't allow me to upload onnx files)

The standard way to reproduce this problem is:

    qonnx.util.cleanup.cleanup("MLP.onnx", out_file="MLP_clean.onnx")
    model = ModelWrapper("MLP_clean.onnx")
    X = create_predict_data(model)
    idict = {model.graph.input[0].name: X}
    y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name]

where create_predict_data basically creates x = np.random.uniform(-1, 1, (1, 259)).
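For completeness, a hedged stand-in for create_predict_data (the real helper is not shown in the thread; the input shape (1, 259) comes from the example MLP model). Note that np.random.uniform returns float64 by default, which is exactly the dtype involved in this issue:

```python
import numpy as np

# Illustrative stand-in for create_predict_data; the actual helper may
# differ. np.random.uniform returns float64 by default, so the model's
# input tensor is fed float64 data unless explicitly cast.
def create_predict_data(model=None):
    return np.random.uniform(-1, 1, (1, 259))
```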

jmitrevs commented 2 years ago

I can get the above version of this model to work by setting the default of fix_float64 to False, but interestingly, the same steps do not work for a float version of this model:

https://drive.google.com/file/d/1Ogjh7H3X8hpT3qktXs-11lVlZjpDCh_Y/view?usp=sharing

For some reason, I get "Found unspecified tensor shapes, try infer_shapes" even though I run the model cleaning first.

maltanar commented 2 years ago

Thanks for providing the example models, I can reproduce the issue and I'm looking into a clean solution for this. The origin of the fix_float64 workaround (which indiscriminately replaces float64 initializers with float32 ones) was some erroneous ONNX models that had standard ONNX nodes with mismatched input data types, e.g. an Add node with one float32 and one float64 input. This violates the ONNX standard and onnxruntime complains when trying to execute them, so we put in a quick-and-dirty fix that replaces float64 initializers with float32 ones.
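In numpy terms, the quick-and-dirty workaround amounts to something like the following (illustrative names, not the actual finn-base code):

```python
import numpy as np

# Sketch of what the fix_float64 workaround does: downcast EVERY float64
# initializer to float32, regardless of whether the node it feeds actually
# has a dtype mismatch. This is what fixes the invalid mixed-dtype models
# but breaks models that legitimately use float64 everywhere.
def fix_float64_initializers(initializers):
    return {
        name: arr.astype(np.float32) if arr.dtype == np.float64 else arr
        for name, arr in initializers.items()
    }
```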

If I remove the workaround completely, we get numerous test failures in FINN. However, for models that actually use float64 for both inputs of a standard node, the workaround actually creates the original problem itself :-) so a cleaner solution could be one of:

1) Identify the source of the invalid ONNX models, fix it and remove the fix_float64 workaround completely. This is what I'm looking into currently. At least one source of float64s seems to be the QONNX-to-FINN-ONNX conversion by @HenniOVP, e.g. https://github.com/Xilinx/finn/blob/dev/src/finn/transformation/qonnx/qonnx_activation_handlers.py#L193-L194 - @HenniOVP could you comment on this? ONNX float attributes actually seem to be allowed only as float32 according to the protobuf spec: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L118

2) Do the float64 fixing more intelligently instead of changing all initializers, e.g. only performing the conversion if the input types are mismatched. However, this will likely require paying attention to the node type (there may be ONNX nodes that are actually valid with mixed precision).

3) Propagate fix_float64 across multiple calls, including the ModelWrapper cleanup and the qonnx cleanup.
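Option (2) could look roughly like this sketch, which only downcasts when a single node mixes float32 and float64 inputs (simplified: a real implementation would also need to consult the node's op type, since some ops may legitimately accept mixed precision):

```python
import numpy as np

# Hedged sketch of the "smarter" float64 fix from option (2). The node's
# inputs are represented here as a plain list of numpy arrays.
def fix_mismatched_node(input_arrays):
    dtypes = {a.dtype for a in input_arrays}
    if np.dtype(np.float64) in dtypes and np.dtype(np.float32) in dtypes:
        # invalid mixed case: unify on float32, like the current workaround
        return [a.astype(np.float32) if a.dtype == np.float64 else a
                for a in input_arrays]
    # uniformly float64 (or anything else) is left untouched, so models
    # that genuinely use float64 keep working
    return input_arrays
```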

maltanar commented 2 years ago

I think I found the answer to the question I asked @HenniOVP above: it's actually not the case that ONNX only accepts 64-bit floats as attributes; rather, the protobuf wrapper for ONNX doesn't recognize numpy float32 (as opposed to native Python float). Working on a fix for that part.
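The root cause can be demonstrated without ONNX at all: np.float64 is a subclass of Python's native float (so protobuf's float attribute setter accepts it), while np.float32 is not:

```python
import numpy as np

# np.float64 shares its memory layout with the C double backing Python's
# float, and numpy makes it an actual subclass of float.
assert isinstance(np.float64(1.0), float)

# np.float32 has no native Python counterpart and is NOT a float subclass,
# so type checks expecting a native float (as in the protobuf wrapper)
# reject it.
assert not isinstance(np.float32(1.0), float)
```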

heborras commented 2 years ago

I was just about to comment something similar. It seems that the issue I mentioned in the code only appears with the attributes and not the initializers. So in line 194 everything works fine when using float32 instead, but in lines 168 to 170 the ONNX numpy helper will complain when using float32 and thus requires float64.

So, in fact line 194 should probably be changed to float32.
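A defensive pattern for that change (hypothetical helper, not part of qonnx): round the value to float32 precision first, then hand protobuf a native Python float, which satisfies both constraints at once:

```python
import numpy as np

# Hypothetical helper illustrating the safe conversion: np.float32(value)
# rounds to single precision, and the outer float() produces the native
# Python type that the protobuf attribute setter expects.
def to_onnx_float_attr(value):
    return float(np.float32(value))
```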