fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Improve parsing of non-nn.Sequential PyTorch models #840

Closed vloncar closed 10 months ago

vloncar commented 11 months ago

Description

In the case of skipped layers, like Flatten or Dropout, the PyTorch converter will incorrectly parse the model inputs; we need to create an input map similar to the one the Keras converter uses. This was the case in #839. Additionally, as observed in #838, parsing of BatchNormalization weights was broken. These fixes are cherry-picked from my development branch for parsing GNNs and are not fully tested standalone, so I'm making this a draft PR for now until I add proper tests.
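To illustrate the input-map idea (this is a hypothetical sketch, not hls4ml's actual converter code; the function and data layout are invented for illustration): when a skipped layer such as Dropout or Flatten is dropped from the parsed graph, any later layer that consumed its output has to be rewired to read the skipped layer's own input instead.

```python
# Sketch of rewiring inputs around skipped layers. Hypothetical helper,
# not part of hls4ml; layer classes treated as skippable are assumptions.
SKIP_LAYERS = ("Dropout", "Flatten")

def build_input_map(layers):
    """layers: list of (name, class_name, input_name) tuples in
    topological order. Returns a dict mapping each kept layer's name to
    the tensor it should actually read, with skipped layers collapsed."""
    tensor_of = {}   # layer name -> tensor name representing its output
    input_map = {}
    for name, class_name, inp in layers:
        src = tensor_of.get(inp, inp)   # follow through skipped layers
        if class_name in SKIP_LAYERS:
            tensor_of[name] = src       # output is just its input's tensor
        else:
            input_map[name] = src
            tensor_of[name] = name
    return input_map

# A non-Sequential-style chain: conv -> flatten -> dropout -> dense.
layers = [
    ("conv", "Conv2d", "input"),
    ("flatten", "Flatten", "conv"),
    ("drop", "Dropout", "flatten"),
    ("dense", "Linear", "drop"),
]
print(build_input_map(layers))  # {'conv': 'input', 'dense': 'conv'}
```

Without such a map, `dense` would still point at `drop`, which no longer exists in the converted graph, which is the failure mode described above.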

Type of change

Tests

Currently lacking. I will add something along the lines of the code shared in #838 and #839.

Checklist

jmitrevs commented 11 months ago

Note that Flatten is not a skip layer for Keras; it gets turned into a Reshape. We should check why we made that decision there and a different one here.

jmitrevs commented 11 months ago

Concerning Flatten always disappearing, whether io_stream or io_parallel: I think the optimizers effectively do that now, but we were worried that this isn't guaranteed to always be the case. More qualifications can be added in the streaming case if it isn't, or if we have a different backend. I'm not sure what's best, but in general I think we should handle Keras and PyTorch the same way, unless there's a good reason for this not to be the case.

vloncar commented 10 months ago

Continued in #848. Closing.