fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Add RNN support for Pytorch #850

Open JanFSchulte opened 10 months ago

JanFSchulte commented 10 months ago

Adds support for RNN layers (GRU, LSTM, RNN) to the PyTorch parser.
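For reference, here is a plain-Python sketch of the GRU recurrence that a converted layer has to reproduce. It is a scalar sketch (input_size = hidden_size = 1) written from the update equations in the PyTorch nn.GRU documentation, not the hls4ml HLS implementation. The helpers gru_step and gru_forward and the weight-dictionary layout are hypothetical.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def gru_step(x: float, h: float, w: dict) -> float:
    """One GRU time step for input_size = hidden_size = 1, following the
    update equations documented for torch.nn.GRU (hypothetical helper)."""
    r = sigmoid(w["ir"] * x + w["b_ir"] + w["hr"] * h + w["b_hr"])  # reset gate
    z = sigmoid(w["iz"] * x + w["b_iz"] + w["hz"] * h + w["b_hz"])  # update gate
    # Candidate state: the reset gate scales the recurrent term, including its bias.
    n = math.tanh(w["in"] * x + w["b_in"] + r * (w["hn"] * h + w["b_hn"]))
    # Interpolate between the candidate and the previous hidden state.
    return (1.0 - z) * n + z * h


def gru_forward(xs, w, h0=0.0):
    """Run gru_step over a sequence and collect the hidden state at every step."""
    h, out = h0, []
    for x in xs:
        h = gru_step(x, h, w)
        out.append(h)
    return out


# With all weights and biases at zero both gates are 0.5 and the candidate is 0,
# so every step simply halves the previous hidden state.
zeros = {k: 0.0 for k in ["ir", "iz", "in", "hr", "hz", "hn",
                          "b_ir", "b_iz", "b_in", "b_hr", "b_hz", "b_hn"]}
print(gru_forward([1.0, 2.0, 3.0], zeros, h0=1.0))  # [0.5, 0.25, 0.125]
```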

Caveat 1: We currently lack an implementation of getitem operations, so we cannot yet return the hidden state after the computation.
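In PyTorch, nn.RNN and nn.GRU return a tuple (output, h_n), and nn.LSTM returns (output, (h_n, c_n)). Selecting an element of that tuple is the getitem operation mentioned above. The following scalar plain-Python sketch of a single-layer tanh RNN returns the same kind of tuple. The function rnn_forward and its parameter values are hypothetical. It shows that, for a single-layer unidirectional RNN, h_n is identical to the last time step of output.

```python
import math


def rnn_forward(xs, w_ih, w_hh, b_ih, b_hh, h0=0.0):
    """Scalar Elman RNN following the torch.nn.RNN update rule
    h_t = tanh(w_ih * x_t + b_ih + w_hh * h_{t-1} + b_hh).
    Returns (output, h_n): all per-step hidden states plus the final one."""
    h = h0
    output = []
    for x in xs:
        h = math.tanh(w_ih * x + b_ih + w_hh * h + b_hh)
        output.append(h)
    return output, h


output, h_n = rnn_forward([0.5, -1.0, 2.0], w_ih=0.8, w_hh=0.3, b_ih=0.1, b_hh=-0.2)
# Single layer, one direction: the final hidden state is the last output step.
print(h_n == output[-1])  # True
```

For nn.LSTM the cell state c_n does not appear in output at all, so it can only be exposed by indexing the returned tuple.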

Caveat 2: We currently support only a single recurrent layer, whereas PyTorch supports stacking multiple layers within the same RNN instance (num_layers > 1).
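With num_layers > 1, PyTorch feeds the output sequence of each layer into the next one. During training, optional dropout is applied between layers. The sketch below is scalar plain Python and shows what multi-layer support would have to reproduce. The helpers rnn_layer and stacked_rnn and the parameter tuples (w_ih, w_hh, b_ih, b_hh) are hypothetical, and each layer starts from a zero hidden state.

```python
import math


def rnn_layer(xs, w_ih, w_hh, b_ih, b_hh):
    """One scalar tanh RNN layer with a zero initial state; returns per-step outputs."""
    h, out = 0.0, []
    for x in xs:
        h = math.tanh(w_ih * x + b_ih + w_hh * h + b_hh)
        out.append(h)
    return out


def stacked_rnn(xs, layers):
    """Multi-layer RNN: the output sequence of layer l is the input sequence of layer l + 1."""
    seq = xs
    for params in layers:
        seq = rnn_layer(seq, *params)
    return seq


layer0 = (0.8, 0.3, 0.1, -0.2)   # (w_ih, w_hh, b_ih, b_hh), arbitrary values
layer1 = (-0.5, 0.6, 0.0, 0.05)
xs = [0.5, -1.0, 2.0]
two_layer = stacked_rnn(xs, [layer0, layer1])
# Equivalent to chaining two single-layer RNNs by hand.
print(two_layer == rnn_layer(rnn_layer(xs, *layer0), *layer1))  # True
```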

Caveat 3: We currently don't support passing non-zero initial hidden states to the RNN.
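When no initial state is passed, the PyTorch recurrent layers default h_0 (and c_0 for LSTM) to zeros, so the currently supported case is equivalent to an explicit all-zero initial state. The scalar sketch below is plain Python, and the function name and parameter values are hypothetical. It makes that equivalence concrete and shows that a non-zero h0 does change the result.

```python
import math


def rnn_final_state(xs, w_ih, w_hh, b_ih, b_hh, h0=None):
    """Scalar tanh RNN returning only h_n; a missing initial state defaults to zero."""
    h = 0.0 if h0 is None else h0
    for x in xs:
        h = math.tanh(w_ih * x + b_ih + w_hh * h + b_hh)
    return h


xs = [0.5, -1.0, 2.0]
params = dict(w_ih=0.8, w_hh=0.3, b_ih=0.1, b_hh=-0.2)
print(rnn_final_state(xs, **params) == rnn_final_state(xs, h0=0.0, **params))  # True
print(rnn_final_state(xs, **params) == rnn_final_state(xs, h0=0.7, **params))  # False
```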

So this implementation is slightly hacky at the moment, but it might serve as a starting point for discussion, and it can be used by interested parties who can live with the current limitations.

Also, this contains parts of https://github.com/fastmachinelearning/hls4ml/pull/848 because I was inattentive.


Tests

Added pytests to confirm that the layers work.


vloncar commented 10 months ago

pre-commit.ci autofix

jmitrevs commented 8 months ago

The tests fail with:

FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'

JanFSchulte commented 2 months ago

All test failures the last time around seemed to be caused by issues with the tests themselves, which I have mostly fixed. The only change I made outside the tests was to add missing includes to some Quartus templates, which fixes compilation errors when uint_8 is used.

Some test failures remain in the case where activations are used through their nn.functional implementations instead of as module classes. I can't reproduce these failures in a standalone file: the exact same code that fails under pytest runs fine as a standalone Python script. I have not yet figured out how to debug it under those circumstances.