InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License
1.2k stars 232 forks source link

Support for ONNX export #87

Closed lbhm closed 1 year ago

lbhm commented 3 years ago

Feature

Enable CompressAI models to be exportable to the ONNX format.

Motivation

I would like to use some of the CompressAI models in a third-party inference framework that can import models from ONNX files. However, in my tests, the models currently cannot be exported to ONNX.

Therefore, I'd like to ask: is it generally possible to rewrite the CompressAI models to support ONNX export? I have only just started reading about the ONNX standard, so my understanding might be incomplete. Possible issues that I have come up with so far are:

Any help/feedback is appreciated!

Additional context

What I tried so far:

import torch
from compressai.zoo import models

# Pretrained bmshj2018-factorized model at quality level 1, optimized for MSE.
net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)

# Random placeholder tensor; it only drives tracing during export.
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Axis 0 (the batch dimension) is left symbolic on both the input and the output.
batch_axis = {0: "batch_size"}
export_options = dict(
    export_params=True,         # embed the trained weights in the ONNX file
    opset_version=11,           # ONNX opset to target
    do_constant_folding=True,   # pre-compute constant subexpressions at export time
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={name: dict(batch_axis) for name in ("input", "output")},
)

# Trace `net` with `x` and write the resulting graph to disk.
torch.onnx.export(net, x, "model.onnx", **export_options)

The above code fails with the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_15137/1585788810.py in <module>
      1 # Export the model
----> 2 torch.onnx.export(net,                  # model being run
      3                   x,                         # model input (or a tuple for multiple inputs)
      4                   "bmshj2018-factorized.onnx",           # where to save the model (can be a file or file-like object)
      5                   export_params=True,        # store the trained parameter weights inside the model file

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    273 
    274     from torch.onnx import utils
--> 275     return utils.export(model, args, f, export_params, verbose, training,
    276                         input_names, output_names, aten, export_raw_ir,
    277                         operator_export_type, opset_version, _retain_param_name,

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     86         else:
     87             operator_export_type = OperatorExportTypes.ONNX
---> 88     _export(model, args, f, export_params, verbose, training, input_names, output_names,
     89             operator_export_type=operator_export_type, opset_version=opset_version,
     90             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference)
    687 
    688             graph, params_dict, torch_out = \
--> 689                 _model_to_graph(model, args, verbose, input_names,
    690                                 output_names, operator_export_type,
    691                                 example_outputs, _retain_param_name,

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
    461     params_dict = _get_named_param_dict(graph, params)
    462 
--> 463     graph = _optimize_graph(graph, operator_export_type,
    464                             _disable_torch_constant_prop=_disable_torch_constant_prop,
    465                             fixed_batch_size=fixed_batch_size, params_dict=params_dict,

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    198             dynamic_axes = {} if dynamic_axes is None else dynamic_axes
    199             torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
--> 200         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    201         torch._C._jit_pass_lint(graph)
    202 

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
    311 def _run_symbolic_function(*args, **kwargs):
    312     from torch.onnx import utils
--> 313     return utils._run_symbolic_function(*args, **kwargs)
    314 
    315 

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/utils.py in _run_symbolic_function(g, block, n, inputs, env, operator_export_type)
    992                     return None
    993                 attrs = {k: n[k] for k in n.attributeNames()}
--> 994                 return symbolic_fn(g, *inputs, **attrs)
    995 
    996         elif ns == "prim":

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py in wrapper(g, *args, **kwargs)
    170             if len(kwargs) == 1:
    171                 assert '_outputs' in kwargs
--> 172             return fn(g, *args, **kwargs)
    173 
    174         return wrapper

~/Projects/tensorrt_test/venv/lib/python3.9/site-packages/torch/onnx/symbolic_opset9.py in _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
   1271 
   1272     if kernel_shape is None or any([i is None for i in kernel_shape]):
-> 1273         raise RuntimeError('Unsupported: ONNX export of convolution for kernel '
   1274                            'of unknown shape.')
   1275 

RuntimeError: Unsupported: ONNX export of convolution for kernel of unknown shape.
vjsrinivas commented 2 years ago

Have there been any updates regarding ONNX support?

fracape commented 1 year ago

We use ONNX in our conversion to SADL, but only for the factorized-prior model. Please check this README.

ZhangYuef commented 5 months ago

"bmshj2018-factorized" can be directly exported with PyTorch:

import onnx
import onnxruntime
import torch
from compressai.zoo import models

# Single source of truth for values shared by the export and the inference run.
ONNX_PATH = "model.onnx"
INPUT_SHAPE = (1, 3, 224, 224)
INPUT_NAMES = ["input"]
OUTPUT_NAMES = ["output"]

net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# Export the inference-time graph rather than the training-mode one.
net.eval()

# NOTE(review): 224x224 works for this model, but exporting cheng2020 at this
# size fails with a 14-vs-16 size mismatch in torch.cat. Hyperprior models
# likely need H and W to be multiples of 64 (e.g. 256x256) -- confirm before
# reusing this script for them.

# Placeholder input used only to trace the graph; no gradients are needed.
dummy_input = torch.randn(INPUT_SHAPE)

torch.onnx.export(
    net,                        # model being run
    dummy_input,                # model input (or a tuple for multiple inputs)
    ONNX_PATH,                  # where to save the model (file path or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=11,           # the ONNX opset to export the model to
    do_constant_folding=True,   # execute constant folding for optimization
    input_names=INPUT_NAMES,
    output_names=OUTPUT_NAMES,
    # Variable-length (batch) axis on both the input and the output.
    dynamic_axes={
        INPUT_NAMES[0]: {0: "batch_size"},
        OUTPUT_NAMES[0]: {0: "batch_size"},
    },
)

# Validate the exported graph before handing it to a runtime.
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)

# Since onnxruntime 1.9, builds with several execution providers (e.g. GPU
# builds) require `providers` to be given explicitly.
onnx_session = onnxruntime.InferenceSession(
    ONNX_PATH, providers=["CPUExecutionProvider"]
)

x = torch.randn(INPUT_SHAPE).numpy()
onnx_output = onnx_session.run(OUTPUT_NAMES, {INPUT_NAMES[0]: x})[0]

However, an error occurs when exporting the Cheng2020 model. Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-48-bddb317a9b45> in <cell line: 12>()
     10 
     11 # Export the model
---> 12 torch.onnx.export(net,                       # model being run
     13                   x,                         # model input (or a tuple for multiple inputs)
     14                   "cheng2020.onnx",              # where to save the model (can be a file or file-like object)

15 frames
/usr/local/lib/python3.10/dist-packages/compressai/models/google.py in forward(self, x)
    543         ctx_params = self.context_prediction(y_hat)
    544         gaussian_params = self.entropy_parameters(
--> 545             torch.cat((params, ctx_params), dim=1)
    546         )
    547         scales_hat, means_hat = gaussian_params.chunk(2, 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.

@fracape @lbhm