Xilinx / finn

Dataflow compiler for QNN inference on FPGAs
https://xilinx.github.io/finn
BSD 3-Clause "New" or "Revised" License

FINN does not handle a model with concatenation layer correctly #853

Closed williamliao28 closed 1 year ago

williamliao28 commented 1 year ago


Quick summary

FINN does not handle a model with a concatenation layer correctly.

Details

Any suggestions are welcome. I have tried the following experiments for this issue.

First, I looked into the source code of the InferConcatLayer() function (lines 1678 to 1737 in convert_to_hls_layers.py). I believe this function converts the Concat ONNX node to the HLS layer StreamingConcat. I found that it already calls InferDataTypes() on line 1736, so is it necessary to call InferDataTypes() again after calling InferConcatLayer()?
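
For reference, a minimal sketch of the pattern in question (import paths assumed from the FINN/QONNX source layout):

import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
from qonnx.transformation.infer_datatypes import InferDataTypes

# InferConcatLayer replaces the standard Concat node with StreamingConcat
model = model.transform(to_hls.InferConcatLayer())
# InferConcatLayer already ends with an InferDataTypes() call internally,
# so this second call may be redundant:
model = model.transform(InferDataTypes())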

Second, I tried calling the InferConcatLayer() function before InferBinaryMatrixVectorActivation() and InferQuantizedMatrixVectorActivation(). However, I got the following error message: AssertionError: Thresholds in MatrixVectorActivation_MatMul_4 can't be expressed with type INT10. The traceback is included in the screenshot below.

[screenshot: finn_traceback]

Steps to Reproduce

I include links below to two example Python scripts and a pretrained model, which may be useful for reproducing the error. The file simple_test.py contains the PyTorch module defining the model with Brevitas quantization. The file simple_test_finn_synthesis.py is the code I run inside the FINN Docker container, and it produces the error message described above. The file simple_test_cifar10_8bit_2epoch.onnx is a pretrained model which can be used directly as the input for simple_test_finn_synthesis.py.

simple_test.py

simple_test_finn_synthesis.py

simple_test_cifar10_8bit_2epoch.onnx

  1. Clone the FINN repository
  2. Checkout the main branch
  3. Start the docker container with the command: bash ./run_docker.sh
  4. Put simple_test_cifar10_8bit_2epoch.onnx under the directory build/quant_model/simple_test_cifar10/lr0.02
  5. Run python simple_test_finn_synthesis.py

Expected behavior

The script simple_test_finn_synthesis.py should run without errors.

Actual behavior

Running python simple_test_finn_synthesis.py generates the error described above.

Originally posted by @williamliao28 in https://github.com/Xilinx/finn/discussions/845

auphelia commented 1 year ago

Thanks @williamliao28 for turning this into an issue!

Could you also let me know which version of FINN you're using?

auphelia commented 1 year ago

It would be best to use the provided template: https://github.com/Xilinx/finn/blob/main/.github/ISSUE_TEMPLATE/bug_report.md

williamliao28 commented 1 year ago

@auphelia I have edited it using the template. I am using the main branch of the FINN repo.

auphelia commented 1 year ago

Hi @williamliao28 ,

I have tested your use case with the current dev branch but was not able to reproduce the error. Your script passes without a problem. I will also test it with main to better understand where the error might have come from. Until then, could you switch to using dev instead?

There was one line in the script I needed to change to make it run with dev, which was the export:

import torch
bo.export_finn_onnx(totensor_pyt, torch.randn(ishape), chkpt_preproc_name)

You might have seen that we moved to QONNX-only export (https://github.com/Xilinx/finn/discussions/850); with the current dev, the finn_onnx export should still be usable for now.
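
For when the finn_onnx export goes away, a sketch of the QONNX-based alternative (the Brevitas function name and arguments here are assumptions based on recent Brevitas releases and may differ by version):

import torch
from brevitas.export import export_qonnx  # name assumed from recent Brevitas releases

export_qonnx(totensor_pyt, torch.randn(ishape), export_path=chkpt_preproc_name)
# the resulting QONNX model then needs ConvertQONNXtoFINN
# (finn.transformation.qonnx.convert_qonnx_to_finn) before the usual FINN flow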

auphelia commented 1 year ago

I've inspected the generated .onnx files in a bit more detail, and the switch to dev actually doesn't solve the problem. No error is thrown only because the accumulator width minimization is no longer run by default.

The error comes from the datatype inference for Concat. Concat is a standard ONNX layer, so the FINN datatype doesn't get propagated correctly; I will look into a possible patch for that. For now, you can work around the problem by adding one additional streamlining transformation at the HLS conversion step in your script.

# imports for this snippet (paths assumed from the FINN source layout)
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds

# convert concat layers
model = model.transform(to_hls.InferConcatLayer())
model.save(build_dir + "/simple_test_hls1.onnx")
# here is the additional streamlining transformation inserted:
# after InferConcatLayer, datatype inference succeeds and the float
# thresholds can be rounded and clipped to integer values
model = model.transform(RoundAndClipThresholds())
model = model.transform(to_hls.InferBinaryMatrixVectorActivation(mem_mode))
model.save(build_dir + "/simple_test_hls2.onnx")
model = model.transform(to_hls.InferQuantizedMatrixVectorActivation(mem_mode))
model.save(build_dir + "/simple_test_hls3.onnx")

The problem is that the RoundAndClipThresholds transformation requires an integer input datatype, but MultiThreshold_6 doesn't have the correct input datatype associated, because datatype inference doesn't propagate correctly through the standard Concat layer. After converting it to StreamingConcat, datatype inference works correctly and the floating-point thresholds of MultiThreshold_6 can be rounded to integer values, which then fixes the rest of the FINN flow.
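
To see this for yourself, a small sketch for inspecting the annotation (the node name is taken from this model; the ModelWrapper methods are from qonnx):

# check the FINN datatype annotation on the input of MultiThreshold_6
node = model.get_node_from_name("MultiThreshold_6")
print(model.get_tensor_datatype(node.input[0]))
# float before the workaround, an INT type after InferConcatLayer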

Since I suggested moving to dev, I would like to add that the minimization of the accumulator width is not run by default anymore. If you would like to make use of this optimization, you will need to add this transformation at the end of your script (after the HLS conversion):

model = model.transform(MinimizeAccumulatorWidth())
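
For completeness, a sketch with the import path for that transformation (assumed from the FINN source tree):

from finn.transformation.fpgadataflow.minimize_accumulator_width import MinimizeAccumulatorWidth

# run after the HLS conversion so the converted layers get minimized widths
model = model.transform(MinimizeAccumulatorWidth())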

auphelia commented 1 year ago

Hi @williamliao28 ,

Thanks for raising this issue! I will link this conversation in the GitHub discussion you started initially.

We merged a patch in qonnx and updated dev to the latest qonnx version; please try it out and let me know if there are any problems. Closing this issue for now, but please feel free to re-open it if necessary.

One remark, because I had a look at your generated files: you should call a few additional streamlining transformations to get rid of the transpose nodes (e.g. AbsorbConsecutiveTransposes) and also to make the necessary data layout conversions, like MakeMaxPoolNHWC, as sketched below.
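
A sketch of those calls (import paths assumed from the FINN source tree; where exactly they go in the streamlining sequence depends on your script):

from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC

# move MaxPool to NHWC layout, then absorb the transpose pairs this creates
model = model.transform(MakeMaxPoolNHWC())
model = model.transform(AbsorbConsecutiveTransposes())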