fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

CNNs with binary inputs and weights need fixes #749

Closed jmitrevs closed 1 year ago

jmitrevs commented 1 year ago

Description

As shown by #740, CNNs with binary quantizers don't currently work properly. This PR attempts to fix it.

Type of change

Tests

A pytest is added. However, it still shows errors in streaming.
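For reference, the test matrix implied by the pytest ids below (e.g. "io_parallel-Vivado") can be sketched as a product of io types and backends. This is a reconstruction from the reported ids, not the actual test source:

```python
import itertools

# Reconstruction of the parametrization grid implied by the pytest ids
# in the results below; not the actual test_binary_cnn.py source.
io_types = ["io_parallel", "io_stream"]
backends = ["Vivado", "Vitis", "Quartus"]

case_ids = [f"{io}-{be}" for io, be in itertools.product(io_types, backends)]
print(case_ids)  # six cases in total
```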

Checklist

jmitrevs commented 1 year ago

Before I pushed 7186346 the test results were:

test_binary_cnn.py::test_model2[io_parallel-Vivado] PASSED                                                    [ 16%]
test_binary_cnn.py::test_model2[io_stream-Quartus] FAILED                                                     [ 33%]
test_binary_cnn.py::test_model2[io_stream-Vivado] FAILED                                                      [ 50%]
test_binary_cnn.py::test_model2[io_stream-Vitis] FAILED                                                       [ 66%]
test_binary_cnn.py::test_model2[io_parallel-Quartus] FAILED                                                   [ 83%]
test_binary_cnn.py::test_model2[io_parallel-Vitis] PASSED                                                     [100%]

After that commit, all six tests fail, so it was not a successful fix.

jmitrevs commented 1 year ago

This is not complete, and the final test has been disabled, but I think the fixes we have here should be merged before we make the 0.7.0 RC. The rest can be added after the RC is built, or later.

The remaining to-dos:

  1. Max Pooling does not propagate XnorPrecisionType from the input to the output, so downstream logic misreads the values: output zeros are interpreted as a literal 0 rather than the -1 they encode.
  2. Conv2D with io_parallel and the Resource strategy fails badly with binary inputs; the problem has not been debugged carefully.

If this is accepted, those to-dos should be moved to an issue.
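To make the first to-do concrete, here is a minimal sketch of the xnor encoding (the helper names are hypothetical, not hls4ml API): a stored bit b represents the value 2*b - 1, so a stored 0 actually means -1, and a pooled output only stays correct if it remains tagged as xnor-encoded.

```python
def decode_xnor(bit):
    """Decode a stored xnor bit into the value it represents (0 -> -1, 1 -> +1)."""
    return 2 * bit - 1

def max_pool_window(bits):
    """Max pooling over stored bits. Because decode_xnor is monotonic,
    the max of the stored bits matches the max of the decoded values --
    but only if the output is still interpreted as xnor-encoded."""
    return max(bits)

window = [0, 0, 0, 0]              # decoded values: [-1, -1, -1, -1]
pooled = max_pool_window(window)

# If the pooled output keeps XnorPrecisionType, it decodes correctly:
assert decode_xnor(pooled) == -1   # max of four -1s is -1

# If the type annotation is dropped, the stored bit is read as a plain number:
assert pooled == 0                 # misread as 0 instead of -1
```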

jmitrevs commented 1 year ago

This fixes a few things:

  1. cast is now called with the correct mult_config in all CNN cases. Previously it was correct in many cases but not all.
  2. cast was fixed for the Quartus backend, since you cannot cast ac_fixed to ac_int without calling .to_ac_int(). More templates may be needed here, depending on whether the input types are ac_fixed or ac_int, to cover all the combinations. (I understand casting ac_int to ac_fixed is allowed.) It's worth looking at in more detail, since we probably do not cover every case.
  3. The binary versions of batchnorm did not make use of n_scale_bias, which is not the same as n_in when n_filt != -1. They have been updated to match the regular batchnorm. (This fix was also copied to Quartus, where it was missing everywhere.)
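To illustrate the third point, a minimal NumPy sketch of why the parameter arrays have n_scale_bias entries rather than n_in when n_filt != -1. The shapes and the channels-last flattening are assumptions for illustration, not hls4ml internals:

```python
import numpy as np

n_filt, n_pix = 3, 4            # assumed conv shape: 3 filters, 4 pixel positions
n_in = n_filt * n_pix           # flat input length seen by the layer: 12
n_scale_bias = n_filt           # one scale/bias pair per filter, not per input

rng = np.random.default_rng(0)
x = rng.standard_normal(n_in)            # flattened feature map (channels-last assumed)
scale = rng.standard_normal(n_scale_bias)
bias = rng.standard_normal(n_scale_bias)

# Correct indexing: each input element uses the parameters of its filter.
# Indexing scale/bias with the flat input index i (as if n_scale_bias == n_in)
# would read past the 3 real parameters.
norm = np.array([scale[i % n_filt] * x[i] + bias[i % n_filt]
                 for i in range(n_in)])

# The binary variant then reduces to a per-element sign threshold,
# emitting xnor-encoded bits (1 -> +1, 0 -> -1).
binarized = np.where(norm > 0, 1, 0)
```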