fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Incorrect hls4ml results for AveragePooling2D/MaxPooling2D Keras layer #995

Open clw5710 opened 5 months ago



Quick summary

The AveragePooling2D Keras layer produces incorrect output when the model is compiled with hls4ml, specifically when pool_size=1 and strides=2. With pool_size=2, the hls4ml output matches Keras. I found this by turning on tracing and saving the Keras and hls4ml layer outputs: they match layer by layer until the AveragePooling2D layer is reached.
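The layer-by-layer comparison described above can be sketched as a small helper that reports the first layer whose outputs diverge. This is a minimal sketch, assuming the Keras and hls4ml traces have already been collected into dicts keyed by layer name and flattened to lists of floats; `first_mismatch`, the layer names, and the default tolerances are hypothetical, not part of hls4ml's API.

```python
import math
from typing import Dict, List, Optional, Sequence


def first_mismatch(
    keras_trace: Dict[str, Sequence[float]],
    hls_trace: Dict[str, Sequence[float]],
    layer_order: List[str],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-3,
) -> Optional[str]:
    """Return the first layer (in layer_order) whose outputs differ, else None."""
    for name in layer_order:
        if name not in keras_trace or name not in hls_trace:
            continue  # layer was not traced on one side; skip it
        k, h = keras_trace[name], hls_trace[name]
        if len(k) != len(h):
            return name  # a size mismatch also counts as a divergence
        for kv, hv in zip(k, h):
            if not math.isclose(kv, hv, rel_tol=rel_tol, abs_tol=abs_tol):
                return name
    return None


# Hypothetical traces: the conv layer matches, the pooling layer does not.
keras_out = {"conv2d": [0.5, 1.0], "average_pooling2d": [0.5, 0.25]}
hls_out = {"conv2d": [0.5, 1.0], "average_pooling2d": [0.75, 0.25]}
print(first_mismatch(keras_out, hls_out, ["conv2d", "average_pooling2d"]))  # average_pooling2d
```

The tolerances allow for ordinary fixed-point rounding, so only genuine divergences are reported.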

Details

I created a simple six-layer model to isolate the problem with AveragePooling2D. I also ran profiling to confirm that the bit widths are set correctly. The trace results show no saturation, and to be safe I widened the precision to 64,6 (i.e. ap_fixed<64,6>). I don't see any pattern relating the incorrect hls4ml values to the Keras values; the mismatch seems to occur whenever pool_size is smaller than strides.
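To make the saturation check concrete: a signed ap_fixed<W,I> has I integer bits (including the sign bit) and W-I fractional bits, so ap_fixed<64,6> covers [-32, 32 - 2^-58]. The sketch below assumes "64,6" means ap_fixed<64,6>; the helper names are hypothetical. It computes the range exactly with `fractions.Fraction` (a double cannot represent 32 - 2^-58) and lists any traced values that would fall outside it.

```python
from fractions import Fraction
from typing import Iterable, List, Tuple


def ap_fixed_range(width: int, integer: int) -> Tuple[Fraction, Fraction, Fraction]:
    """Exact (min, max, lsb) of a signed ap_fixed<width, integer>.

    `integer` includes the sign bit, as in the HLS ap_fixed type.
    """
    lsb = Fraction(1, 2 ** (width - integer))  # weight of the last fractional bit
    lo = -Fraction(2 ** (integer - 1))
    hi = Fraction(2 ** (integer - 1)) - lsb
    return lo, hi, lsb


def saturating(values: Iterable[float], width: int, integer: int) -> List[float]:
    """Return the values that lie outside the representable range."""
    lo, hi, _ = ap_fixed_range(width, integer)
    # Fraction compares exactly against float, so no rounding at the edges.
    return [v for v in values if v < lo or v > hi]


print(saturating([0.5, -31.9, 40.0], 64, 6))  # [40.0]
```

An empty result on the traced outputs is consistent with the observation above that nothing saturates.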

Steps to Reproduce

The code is in this gist: https://gist.github.com/clw5710/cd3ba4dfd23db151d1a6fbf10db2ec30 . The model can be trained and compiled with hls4ml very quickly. The file nn_pooling_test.py contains the model architecture, and two Jupyter notebooks train the model and build it with hls4ml. The training accuracy is poor, but I'm only interested in the hls4ml AveragePooling2D results matching Keras.

I also included the CSV output files, which show that the first convolutional layer matches but the pooling layer does not. If pool_size is set to 2 in nn_pooling_test.py, the results match.

Expected behavior

I expected the hls4ml AveragePooling2D data to match the Keras AveragePooling2D data.

Actual behavior

The data is incorrect when pool_size=1 and strides=2. When pool_size=2 and strides=2, the AveragePooling2D data matches between Keras and hls4ml. The pool_size < strides edge case appears to be what triggers the problem.
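For reference, the expected output for these cases can be computed with a plain-Python pooling routine. This is a minimal single-channel sketch, assuming 'valid' padding (the Keras default; the padding used in the gist's model is not stated here). `pool2d` is a hypothetical helper, not hls4ml or Keras code. With pool_size=1 and strides=2, both average and max pooling reduce to taking every other row and column, so the odd-indexed inputs should not affect the output at all.

```python
from typing import Callable, List, Sequence

Grid = List[List[float]]


def pool2d(x: Sequence[Sequence[float]], pool: int, stride: int,
           op: Callable[[List[float]], float]) -> Grid:
    """Single-channel 2D pooling with 'valid' padding.

    The window for output (i, j) starts at (i*stride, j*stride) and spans
    pool x pool inputs, so pool size and stride are independent.
    """
    h, w = len(x), len(x[0])
    out_h = (h - pool) // stride + 1  # Keras 'valid' output size
    out_w = (w - pool) // stride + 1
    out: Grid = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            window = [x[i * stride + di][j * stride + dj]
                      for di in range(pool) for dj in range(pool)]
            row.append(op(window))
        out.append(row)
    return out


def avg(v: List[float]) -> float:
    return sum(v) / len(v)


x = [[float(4 * r + c) for c in range(4)] for r in range(4)]  # values 0..15
print(pool2d(x, pool=1, stride=2, op=avg))  # [[0.0, 2.0], [8.0, 10.0]]
print(pool2d(x, pool=2, stride=2, op=avg))  # [[2.5, 4.5], [10.5, 12.5]]
print(pool2d(x, pool=1, stride=2, op=max))  # [[0.0, 2.0], [8.0, 10.0]]
```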

Possible fix

I believe the fix belongs in the pooling2d_cl function in nnet_pooling.h, here: https://github.com/fastmachinelearning/hls4ml/blob/c44fbc977591d913188c3ff5d789a3d29dfbb256/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h#L200
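When checking `pooling2d_cl`, it may help to have an independent list of which flat input elements each output pixel should read. The sketch below assumes a flattened channels-last layout (index = (row * in_w + col) * n_filt + f, matching Keras' channels_last ordering; that hls4ml's io_parallel arrays use this layout is an assumption here) and 'valid' padding. `pool_input_indices` is a hypothetical helper, not hls4ml code. The window offsets are bounded by the pool size while the window origin advances by the stride; a loop that used one quantity for both would only be correct when pool_size equals strides.

```python
from typing import List


def pool_input_indices(in_h: int, in_w: int, n_filt: int,
                       pool_h: int, pool_w: int,
                       stride_h: int, stride_w: int,
                       f: int = 0) -> List[List[int]]:
    """Flat input indices read for each output pixel of channel f.

    Assumes flattened channels-last data and 'valid' padding.
    """
    out_h = (in_h - pool_h) // stride_h + 1
    out_w = (in_w - pool_w) // stride_w + 1
    gathered = []
    for oi in range(out_h):
        for oj in range(out_w):
            idx = []
            for di in range(pool_h):          # offsets bounded by the pool size...
                for dj in range(pool_w):
                    row = oi * stride_h + di  # ...origin advanced by the stride
                    col = oj * stride_w + dj
                    idx.append((row * in_w + col) * n_filt + f)
            gathered.append(idx)
    return gathered


# 4x4 single-channel input, pool_size=1, strides=2
print(pool_input_indices(4, 4, 1, 1, 1, 2, 2))  # [[0], [2], [8], [10]]
```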