fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Fix pooling layers when padding is applied from the left/top #757

Closed JanFSchulte closed 1 year ago

JanFSchulte commented 1 year ago

When padding is applied from the left (or from the top and left in the 2D case), the indices into the data array currently point to the wrong elements when the pools are filled, because the index shift caused by the padding is not taken into account. In Keras this only happens when pool_size is set above 2, so that with "same" padding the padding is applied not only to the right but also to the left side of the input.

This PR contains a simple fix that shifts the indices to take the offset into account. The change is transparent when no padding from the left/top is required.
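
The index shift described above can be illustrated outside of HLS with a small NumPy sketch. It is not the hls4ml implementation: the helper names `same_padding` and `max_pool1d_same` are hypothetical, and the padding split assumes the usual TensorFlow/Keras "same" rule, where any odd remainder goes to the right.

```python
import math

import numpy as np


def same_padding(n_in, pool_size, stride):
    """Return (n_out, pad_left, pad_right), assuming the TF/Keras "same" rule."""
    n_out = math.ceil(n_in / stride)
    pad_total = max((n_out - 1) * stride + pool_size - n_in, 0)
    pad_left = pad_total // 2  # the odd remainder goes to the right
    return n_out, pad_left, pad_total - pad_left


def max_pool1d_same(x, pool_size, stride=None):
    """1D max pooling with "same" padding over a 1D array (hypothetical helper)."""
    stride = pool_size if stride is None else stride  # Keras defaults strides to pool_size
    n_in = len(x)
    n_out, pad_left, _ = same_padding(n_in, pool_size, stride)
    out = np.empty(n_out, dtype=x.dtype)
    for i in range(n_out):
        window = []
        for j in range(pool_size):
            pos = i * stride + j   # position in the padded input
            idx = pos - pad_left   # shift back to an index into the unpadded data
            if 0 <= idx < n_in:    # positions outside this range are padding
                window.append(x[idx])
        out[i] = max(window)
    return out
```

With `pool_size=2` and an odd input length, `pad_left` is 0 and the shift has no effect; with `pool_size=3` on 124 inputs, `pad_left` is 1, so reading `x[pos]` instead of `x[pos - pad_left]` would pick the neighbouring element.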

Additionally, the step that removes the zero padding from the result in the case of AveragePooling was not done correctly: the division of two integers always produced an integer, so the result was always rescaled with a factor of 1. This has been fixed by casting the numerator to data_T.
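
The effect of the integer division can be reproduced in a few lines of Python. This is only a sketch of the arithmetic, not the HLS code: the function name and arguments are hypothetical, and `//` stands in for C++ integer division.

```python
def rescale_padded_average(padded_avg, pool_size, n_valid, integer_division=False):
    """Undo the effect of zero padding on an average (hypothetical helper).

    padded_avg is the window sum divided by pool_size, with padded zeros included;
    n_valid is the number of real (non-padded) elements in the window.
    """
    if integer_division:
        factor = pool_size // n_valid  # truncates, e.g. 3 // 2 == 1
    else:
        factor = pool_size / n_valid   # true division, e.g. 3 / 2 == 1.5
    return padded_avg * factor


# A window of size 3 holding the real values 4.0 and 8.0 plus one padded zero:
padded_avg = (4.0 + 8.0 + 0.0) / 3
correct = rescale_padded_average(padded_avg, pool_size=3, n_valid=2)
truncated = rescale_padded_average(padded_avg, pool_size=3, n_valid=2, integer_division=True)
```

Here `correct` recovers the mean of the two real values, while `truncated` leaves the padded average unchanged because the factor collapses to 1.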

Tests

Problem can be reproduced and fix verified with this small script:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import MaxPooling1D

import hls4ml

# Random input; pool_size above 2 so that "same" padding is also applied on the left
input_shape = (124, 5)
X_input = np.random.rand(100, *input_shape)

keras_model = tf.keras.models.Sequential()
keras_model.add(MaxPooling1D(pool_size=3, padding="same", input_shape=input_shape))
keras_model.compile()

# Convert the Keras model to an hls4ml model with the default configuration
hls_cfg = hls4ml.utils.config_from_keras_model(keras_model)
output_dir = "test_keras"

hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model, hls_config=hls_cfg, output_dir=output_dir, backend="Vivado"
)
hls_model.compile()

# Verify accuracy: the hls4ml output should match Keras within the tolerance
keras_prediction = keras_model.predict(X_input)
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)

np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=3e-2)

A test was added in test/pytest/test_pooling.py to verify the fixes.

JanFSchulte commented 1 year ago

To me it looks like this test failure is unrelated to the changes in this PR. It looks like tensorflow was not correctly set up:

File "/usr/local/lib/python3.10/site-packages/qkeras/__init__.py", line 20, in <module>
    import tensorflow as tf
ModuleNotFoundError: No module named 'tensorflow'

jmduarte commented 1 year ago

Hm it seems tensorflow is not a requirement for qkeras? I always thought it was: https://github.com/google/qkeras/blob/v0.9.0/setup.py#L41-L50

So I think this means we should explicitly add tensorflow as a requirement to hls4ml @vloncar @jmitrevs

vloncar commented 1 year ago

@JanFSchulte Can you add a test case (test_pooling.py) that checks for this? You can use the existing test for global pooling as a starting point, just change a few things. Test it with all three backends and both io_parallel and io_stream. You may find the results interesting :wink:
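
A minimal sketch of the requested test grid, independent of pytest and hls4ml, could look like the following. The backend list is an assumption (the thread only names Vivado and Vitis; Quartus is assumed to be the third), and the `output_dir` naming scheme is hypothetical.

```python
from itertools import product

# Assumed backend names: only Vivado and Vitis are named in this thread
BACKENDS = ["Vivado", "Vitis", "Quartus"]
IO_TYPES = ["io_parallel", "io_stream"]


def pooling_test_cases():
    """Enumerate every backend/io_type combination as keyword-argument dicts."""
    return [
        {
            "backend": backend,
            "io_type": io_type,
            # Hypothetical naming scheme so that each case writes to its own directory
            "output_dir": f"hls4mlprj_pooling_{backend}_{io_type}",
        }
        for backend, io_type in product(BACKENDS, IO_TYPES)
    ]
```

In a pytest file the same grid would typically be expressed with two stacked `pytest.mark.parametrize` decorators, one for the backend and one for the io_type.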

vloncar commented 1 year ago

> So I think this means we should explicitly add tensorflow as a requirement to hls4ml @vloncar @jmitrevs

I think the reason we didn't do it is that in the old days we wanted to support using the pytorch and onnx converters if TF wasn't installed, and the other way around. This was never fully implemented, but perhaps it is not a nonsense idea. I'm surprised that qkeras and its dependencies (tfmot and keras-tuner) don't explicitly require tensorflow. Maybe there's some reason for that? I also see that in the hls4ml-tutorial environment we install tensorflow explicitly. Can we do the same for the sphinx environment? Or we could revert the cleanup change that caused it.
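
The idea of enabling converters only for the frameworks that are installed could be sketched as below. This is an illustration rather than how hls4ml is organised: the `FRONTEND_MODULES` mapping and the `available_frontends` helper are hypothetical, and only the standard-library call `importlib.util.find_spec` is used.

```python
from importlib.util import find_spec

# Hypothetical mapping from a converter name to the top-level module it needs
FRONTEND_MODULES = {"keras": "tensorflow", "pytorch": "torch", "onnx": "onnx"}


def available_frontends(modules=FRONTEND_MODULES):
    """Return the converter names whose framework can be imported."""
    return sorted(
        name for name, module in modules.items() if find_spec(module) is not None
    )
```

Checking with `find_spec` avoids importing the heavy frameworks just to learn whether they are present.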

JanFSchulte commented 1 year ago

I have added the fixes to Vitis and also fixed another bug in the case of average pooling. Tests are added, but only for the case of io_parallel, since padding is not supported in case of io_stream.