InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Request ONNX export support for Cheng2020 model #296

Open ZhangYuef opened 1 month ago

ZhangYuef commented 1 month ago

Feature

Support exporting the Cheng2020 model to ONNX format.

Motivation

To deploy the model on various hardware platforms.

Additional context

This is my conversion code:

import torch
from compressai.zoo import models

# net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# net = cheng2020_anchor(quality=5, pretrained=True).to(device)
net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True)

# Some dummy input
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export the model
torch.onnx.export(net,                       # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "cheng2020.onnx",          # where to save the model (file path or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX opset version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                                'output': {0: 'batch_size'}})

The following error occurs:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-48-bddb317a9b45> in <cell line: 12>()
     10 
     11 # Export the model
---> 12 torch.onnx.export(net,                       # model being run
     13                   x,                         # model input (or a tuple for multiple inputs)
     14                   "cheng2020.onnx",              # where to save the model (can be a file or file-like object)

15 frames
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    514     """
    515 
--> 516     _export(
    517         model,
    518         args,

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1610             _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
   1611 
-> 1612             graph, params_dict, torch_out = _model_to_graph(
   1613                 model,
   1614                 args,

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1132 
   1133     model = _pre_trace_quant_model(model, args)
-> 1134     graph, params, torch_out, module = _create_jit_graph(model, args)
   1135     params_dict = _get_named_param_dict(graph, params)
   1136 

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _create_jit_graph(model, args)
   1008         return graph, params, torch_out, None
   1009 
-> 1010     graph, torch_out = _trace_and_get_graph_from_model(model, args)
   1011     _C._jit_pass_onnx_lint(graph)
   1012     state_dict = torch.jit._unique_state_dict(model)

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args)
    912     prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    913     torch.set_autocast_cache_enabled(False)
--> 914     trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    915         model,
    916         args,

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    449             prior = set_eval_frame(callback)
    450             try:
--> 451                 return fn(*args, **kwargs)
    452             finally:
    453                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py in inner(*args, **kwargs)
     34     @functools.wraps(fn)
     35     def inner(*args, **kwargs):
---> 36         return fn(*args, **kwargs)
     37 
     38     return inner

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1308     if not isinstance(args, tuple):
   1309         args = (args,)
-> 1310     outs = ONNXTracedModule(
   1311         f, strict, _force_outplace, return_inputs, _return_inputs_states
   1312     )(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in forward(self, *args)
    136                 return tuple(out_vars)
    137 
--> 138         graph, out = torch._C._create_graph_by_tracing(
    139             wrapper,
    140             in_vars + module_state,

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in wrapper(*args)
    127             if self._return_inputs_states:
    128                 inputs_states.append(_unflatten(in_args, in_desc))
--> 129             outs.append(self.inner(*trace_inputs))
    130             if self._return_inputs_states:
    131                 inputs_states[0] = (inputs_states[0], trace_inputs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1520                 recording_scopes = False
   1521         try:
-> 1522             result = self.forward(*input, **kwargs)
   1523         finally:
   1524             if recording_scopes:

/usr/local/lib/python3.10/dist-packages/compressai/models/google.py in forward(self, x)
    543         ctx_params = self.context_prediction(y_hat)
    544         gaussian_params = self.entropy_parameters(
--> 545             torch.cat((params, ctx_params), dim=1)
    546         )
    547         scales_hat, means_hat = gaussian_params.chunk(2, 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.
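
The size mismatch itself is not ONNX-specific: cheng2020-anchor downsamples the input by a factor of 16 for the latent y and by a factor of 64 for the hyper-latent z, so a 224x224 input (not a multiple of 64) produces a 14x14 y while the hyper-decoder output comes back at 16x16, which matches the "Expected size 16 but got size 14" above. Padding the dummy input up to a multiple of 64 should avoid this particular error; a minimal sketch (the pad_to_multiple helper and the replicate padding mode are illustrative assumptions, not from the report):

import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=64):
    # Hypothetical helper: pad H and W up to the next multiple of `multiple`
    # so the hyper-decoder output matches the latent spatially.
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad pads the last two dims in the order (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

x = torch.randn(1, 3, 224, 224)
x_padded = pad_to_multiple(x)  # -> (1, 3, 256, 256)

Even with matching sizes, a traced forward() only captures the masked-convolution context model used during training; the actual compress()/decompress() path decodes the latent autoregressively, so an exported graph would likely not reproduce real codec behavior.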
ZhangYuef commented 1 month ago

Related issue: #87.
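
A hedged partial workaround in the meantime: export only the convolutional transforms, which sidesteps the entropy-model path that fails above. A sketch assuming the standard CompressAI attribute name net.g_a (and symmetrically net.g_s for the synthesis side); this is not a full codec export:

import torch
from compressai.zoo import models

net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True).eval()

# Spatial dimensions are a multiple of 64 to match the model's total
# downsampling factor.
x = torch.randn(1, 3, 256, 256)

# g_a is a plain feed-forward CNN, so tracing it never reaches the
# context_prediction / entropy_parameters code in the traceback.
torch.onnx.export(net.g_a,
                  x,
                  "cheng2020_g_a.onnx",
                  opset_version=11,
                  input_names=['input'],
                  output_names=['y'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'y': {0: 'batch_size'}})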