fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Add support for flattening to the pytorch parser #852

Closed JanFSchulte closed 10 months ago

JanFSchulte commented 10 months ago

Currently, Flatten layers are skipped by the pytorch parser. The operation is also not flagged as unsupported, so the model parses but produces incorrect results. This PR adds support for these layers by adding them to the parser. The optimizer pass that converts operations to channels_last for pytorch models is adapted to transpose the input of the Flatten layer so that the output elements are in the correct order.
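To see why the transpose is needed, here is a small numpy sketch (not from the PR, just an illustration): PyTorch flattens tensors in channels-first (C, H, W) order, while hls4ml stores them channels-last, so flattening the channels-last tensor directly scrambles the element order unless a transpose is inserted first.

```python
import numpy as np

# PyTorch tensors are channels-first: (C, H, W). Take C=2, H=2, W=3.
x_cf = np.arange(12).reshape(2, 2, 3)

# What torch.flatten would produce (channels-first element order):
expected = x_cf.flatten()

# hls4ml converts the model to channels_last internally: (H, W, C).
x_cl = np.transpose(x_cf, (1, 2, 0))

# Flattening the channels_last tensor directly gives the wrong order:
wrong = x_cl.flatten()
assert not np.array_equal(wrong, expected)

# Transposing back to channels-first before the Flatten (what the
# adapted optimizer pass effectively does) recovers the PyTorch order:
right = np.transpose(x_cl, (2, 0, 1)).flatten()
assert np.array_equal(right, expected)
```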

Tests

Verified that a simple model with a Conv2D and a Flatten operation gives correct results, for both the torch.nn.Flatten and torch.flatten() interfaces to this operation in pytorch. A pytest has been added to verify this.
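A hypothetical sketch of the kind of model the pytest exercises (class and layer names are illustrative, not taken from the PR): the same Conv2d followed by a flatten, once via the torch.nn.Flatten module and once via the torch.flatten() function, with shared weights so the two outputs can be compared directly.

```python
import torch
import torch.nn as nn

class FlattenModule(nn.Module):
    """Conv2d + flatten via the torch.nn.Flatten module interface."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.flatten = nn.Flatten()  # flattens everything after the batch dim

    def forward(self, x):
        return self.flatten(self.conv(x))

class FlattenFunctional(nn.Module):
    """Conv2d + flatten via the torch.flatten() functional interface."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x):
        return torch.flatten(self.conv(x), start_dim=1)

x = torch.randn(1, 3, 8, 8)
m1, m2 = FlattenModule(), FlattenFunctional()
m2.load_state_dict(m1.state_dict())  # share conv weights between the two
assert torch.allclose(m1(x), m2(x))  # both interfaces flatten identically
```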

vloncar commented 10 months ago

LGTM. I also added support for fully parsing start_dim and end_dim of nn.Flatten to our Reshape. Once tests pass, I'll merge this if there are no objections in the meantime.