ENOT-AutoDL / onnx2torch

Convert ONNX models to PyTorch.
Apache License 2.0

Support for exporting traced graph? #225

Status: Open. KyleErnewein opened this issue 4 weeks ago.

KyleErnewein commented 4 weeks ago

Hi, I have a use case that ingests ONNX models, converts them to PyTorch, and then exports a traced graph (via torch.export.export()).

After converting ONNX to torch (via onnx2torch.convert()), I'm running into issues tracing the graph because of dynamic control flow in the converted torch model.

Is there any plan for onnx2torch to support this type of use case? Or are there any recommendations for how to work around the dynamic control flow in the converted torch model?

One example of a problematic op is Reshape: the converted torch model contains logic that is conditional on the values of the shape input, to replicate ONNX's special handling of shape dimensions with a value of 0 (meaning "copy this dimension from the input shape").
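For reference, the ONNX Reshape rule described above can be written out in plain Python. This is a minimal sketch of the ONNX spec (including the allowzero attribute introduced in opset 14), not onnx2torch's implementation; resolve_onnx_reshape is a hypothetical helper name. When both shapes are known Python ints, every branch below runs on plain values rather than on tensor contents.

```python
from math import prod


def resolve_onnx_reshape(input_shape, target_shape, allowzero=0):
    """Resolve an ONNX Reshape target shape against a known input shape.

    Hypothetical helper illustrating the ONNX rules:
    - allowzero=0: a 0 copies the input dimension at the same index.
    - allowzero=1: a 0 is a literal zero-sized dimension.
    - At most one -1 is allowed; it is inferred from the remaining elements.
    """
    resolved = []
    for i, d in enumerate(target_shape):
        if d == 0 and not allowzero:
            resolved.append(input_shape[i])  # "copy this dim from the input"
        else:
            resolved.append(d)

    total = prod(input_shape)
    if resolved.count(-1) > 1:
        raise ValueError("at most one -1 is allowed in the target shape")
    if -1 in resolved:
        idx = resolved.index(-1)
        known = prod(d for j, d in enumerate(resolved) if j != idx)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {input_shape} -> {target_shape}")
        resolved[idx] = total // known
    elif prod(resolved) != total:
        raise ValueError(f"cannot reshape {input_shape} to {target_shape}")
    return resolved


print(resolve_onnx_reshape((10, 20), (20, 10)))  # [20, 10]
print(resolve_onnx_reshape((2, 3, 4), (0, -1)))  # [2, 12]
```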

Here's code to reproduce that issue:

```python
import io

import onnx
import onnx2torch
import torch


# Create an ONNX model containing a single Reshape node:
class M(torch.nn.Module):
    def forward(self, x):
        x = x.reshape(20, 10)
        return x


torch_args = (torch.rand(10, 20),)
with io.BytesIO() as tmp_file:
    # Export to an in-memory buffer, then parse it back as an ONNX model
    torch.onnx.export(model=M(), args=torch_args, f=tmp_file)
    onnx_model = onnx.load_from_string(tmp_file.getvalue())

# Convert ONNX --> torch
converted_torch = onnx2torch.convert(onnx_model)

# Export the traced graph (ExportedProgram); this is the call that fails:
ep = torch.export.export(converted_torch, args=torch_args)
```

This raises the following error (snippet; the actual trace is very long):

```
...
UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
  File "<eval_with_key>.1", line 6, in forward
    reshape = self.Reshape(input_1, constant);  input_1 = constant = None
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 36, in forward
    return _forward()
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 31, in _forward
    return self._do_reshape(input_tensor, shape)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 20, in _do_reshape
    if torch.any(shape == 0):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Any insights would be appreciated. Thanks!

KyleErnewein commented 4 weeks ago

For reference, this is my Python environment (output of pip freeze):

```
asttokens==2.4.1
backcall==0.2.0
decorator==5.1.1
executing==2.0.1
filelock==3.15.4
fsspec==2024.6.1
ipython==8.12.3
jedi==0.19.1
Jinja2==3.1.4
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mpmath==1.3.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
onnx==1.16.2
onnx2torch==1.5.15
parso==0.8.4
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
prompt_toolkit==3.0.47
protobuf==5.27.3
ptyprocess==0.7.0
pure_eval==0.2.3
Pygments==2.18.0
six==1.16.0
stack-data==0.6.3
sympy==1.13.2
torch==2.4.0
torchvision==0.19.0
traitlets==5.14.3
triton==3.0.0
typing_extensions==4.12.2
wcwidth==0.2.13
```