fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Improved parsing of pytorch models using torch.FX - Clean #799

Closed JanFSchulte closed 1 year ago

JanFSchulte commented 1 year ago

Refreshed version of https://github.com/fastmachinelearning/hls4ml/pull/723, opened to leave behind the messy git history of that PR.

Current parsing of pytorch models uses a loop over the named_modules of the model (https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch_to_hls.py#L163). This has several disadvantages: only layers defined as members of the model class appear in named_modules, so operations used directly in forward() (such as torch.relu calls or tensor additions) are invisible to the parser, and the iteration order of the modules does not necessarily match the order of operations in forward().

In this PR, we propose to fix this by first creating a graph representation of the model's forward() function using the symbolic tracing functionality of https://pytorch.org/docs/stable/fx.html. Each operation in forward() is represented by a node in the graph. Nodes can be of the types placeholder, get_attr, call_function, call_module, call_method, and output.

For example, for this model

class MyModuleConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y

the resulting graph representation is

graph():
    %x : [#users=1] = placeholder[target=x]
    %conv : [#users=2] = call_module[target=conv](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%conv,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%relu, %conv), kwargs = {})
    %relu_1 : [#users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    return relu_1
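
For reference, a graph like the one above can be produced directly with torch.fx symbolic tracing; a minimal sketch (hls4ml performs this step internally, so the exact invocation in the converter may differ):

import torch
from torch.fx import symbolic_trace

# Symbolically trace the forward() of the example model defined above.
traced = symbolic_trace(MyModuleConvRelu())

# Prints the textual graph representation shown above.
print(traced.graph)

# Each node carries an opcode, a target, and its input arguments.
for node in traced.graph.nodes:
    print(node.op, node.target, node.args)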

As the nodes in the graph follow the order of operations of the forward() function, we can then simply loop over them and parse each node into one node of the hls4ml model representation. For the parsing of the individual layers, existing code is reused where available, without significant changes. This PR also adds support for more types of layers.
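
As an illustration, a node loop of this kind might look like the following sketch (the handler logic and the layer-dictionary format are illustrative placeholders, not hls4ml's actual internal API):

from torch.fx import symbolic_trace

def parse_pytorch_model(model):
    """Turn a traced torch.fx graph into a list of layer dicts (sketch)."""
    traced = symbolic_trace(model)
    modules = dict(model.named_modules())
    layers = []
    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            # Model input: becomes an input layer.
            layers.append({'class_name': 'Input', 'name': node.name})
        elif node.op == 'call_module':
            # A layer defined in __init__, e.g. torch.nn.Conv2d.
            module = modules[node.target]
            layers.append({'class_name': type(module).__name__, 'name': node.name})
        elif node.op == 'call_function':
            # An operation used directly in forward(), e.g. torch.relu or '+'.
            fname = getattr(node.target, '__name__', str(node.target))
            layers.append({'class_name': fname, 'name': node.name})
        elif node.op == 'output':
            pass  # marks the model output
    return layers

For the example model above, this yields Input, Conv2d, relu, add, relu entries, in the order of operations of forward().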

The types of layers currently understood by the parser are

This PR also fixes https://github.com/fastmachinelearning/hls4ml/issues/409

Changes are mostly confined to the frontend, but small changes are made to the backend templates for pooling layers, adding an option to include zero-padded entries in average pooling operations.
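
For context, this corresponds to the behavior pytorch exposes through the count_include_pad argument of AvgPool2d, where zero-padded entries are counted in the divisor by default; a small illustration:

import torch

x = torch.ones(1, 1, 2, 2)

# With padding=1, the corner windows of a 2x2 kernel cover three zero
# entries. count_include_pad controls whether those zeros enter the mean.
incl = torch.nn.AvgPool2d(2, stride=1, padding=1, count_include_pad=True)
excl = torch.nn.AvgPool2d(2, stride=1, padding=1, count_include_pad=False)

print(incl(x)[0, 0, 0, 0])  # tensor(0.2500) -- zeros included in the average
print(excl(x)[0, 0, 0, 0])  # tensor(1.)     -- zeros excluded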

One big difference between pytorch and keras is the data format of the input tensors: pytorch defaults to channels_first, while keras uses channels_last. The built-in tools in pytorch for converting a model to channels_last do not work for all input dimensions. Therefore, functionality has been added to transpose the inputs within hls4ml so that the existing channels_last implementations of the layers can be used. By default, inputs are transposed for io_parallel but not for io_stream, since we don't have transpose layers for all dimensions in io_stream. Outputs are not transposed by default, but this can be switched on, again only for io_parallel.
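
For concreteness, the transposition from channels_first to channels_last is a plain axis permutation, shown here in numpy (the hls4ml-internal implementation may differ):

import numpy as np

# A pytorch-style 2D input: (batch, channels, height, width)
x = np.random.rand(1, 3, 32, 32)

# Transpose to the keras-style channels_last layout: (batch, height, width, channels)
x_cl = x.transpose(0, 2, 3, 1)

# 1D case: (batch, channels, width) -> (batch, width, channels)
x1d = np.random.rand(1, 3, 128)
x1d_cl = x1d.transpose(0, 2, 1)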

Limitations:

Type of change

Tests

The new parsing was tested using 5-6 different pytorch model examples from around the web. In addition, I verified that the two example models for pytorch included with hls4ml are parsed successfully. A test for the API was added in the test/pytest folder, analogous to the test for the keras parser. All tests pass successfully.

Checklist

vloncar commented 1 year ago

This has reached a very high level of stability and feature completeness; it is far better than what is currently in the main branch. Support for some ops is not complete and some guards still need to be added to ensure proper parsing, but these are mostly corner cases that we can address later. So I would propose we merge it in this state and continue with bugfixes as we go. There are significant developments built on top of this already that I wouldn't like to push as part of this PR.