import torch
from compressai.zoo import models
# net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# net = cheng2020_anchor(quality=5, pretrained=True).to(device)
net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True)
# Some dummy input
x = torch.randn(1, 3, 224, 224, requires_grad=True)
# Export the model
torch.onnx.export(net, # model being run
x, # model input (or a tuple for multiple inputs)
"cheng2020.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input': {0 : 'batch_size'}, # variable length axes
'output': {0 : 'batch_size'}}
)
Error occurs:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-48-bddb317a9b45>](https://localhost:8080/#) in <cell line: 12>()
10
11 # Export the model
---> 12 torch.onnx.export(net, # model being run
13 x, # model input (or a tuple for multiple inputs)
14 "cheng2020.onnx", # where to save the model (can be a file or file-like object)
15 frames
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
514 """
515
--> 516 _export(
517 model,
518 args,
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
1610 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1611
-> 1612 graph, params_dict, torch_out = _model_to_graph(
1613 model,
1614 args,
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1132
1133 model = _pre_trace_quant_model(model, args)
-> 1134 graph, params, torch_out, module = _create_jit_graph(model, args)
1135 params_dict = _get_named_param_dict(graph, params)
1136
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _create_jit_graph(model, args)
1008 return graph, params, torch_out, None
1009
-> 1010 graph, torch_out = _trace_and_get_graph_from_model(model, args)
1011 _C._jit_pass_onnx_lint(graph)
1012 state_dict = torch.jit._unique_state_dict(model)
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _trace_and_get_graph_from_model(model, args)
912 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
913 torch.set_autocast_cache_enabled(False)
--> 914 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
915 model,
916 args,
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
449 prior = set_eval_frame(callback)
450 try:
--> 451 return fn(*args, **kwargs)
452 finally:
453 set_eval_frame(prior)
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py](https://localhost:8080/#) in inner(*args, **kwargs)
34 @functools.wraps(fn)
35 def inner(*args, **kwargs):
---> 36 return fn(*args, **kwargs)
37
38 return inner
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1308 if not isinstance(args, tuple):
1309 args = (args,)
-> 1310 outs = ONNXTracedModule(
1311 f, strict, _force_outplace, return_inputs, _return_inputs_states
1312 )(*args, **kwargs)
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in forward(self, *args)
136 return tuple(out_vars)
137
--> 138 graph, out = torch._C._create_graph_by_tracing(
139 wrapper,
140 in_vars + module_state,
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in wrapper(*args)
127 if self._return_inputs_states:
128 inputs_states.append(_unflatten(in_args, in_desc))
--> 129 outs.append(self.inner(*trace_inputs))
130 if self._return_inputs_states:
131 inputs_states[0] = (inputs_states[0], trace_inputs)
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:
[/usr/local/lib/python3.10/dist-packages/compressai/models/google.py](https://localhost:8080/#) in forward(self, x)
543 ctx_params = self.context_prediction(y_hat)
544 gaussian_params = self.entropy_parameters(
--> 545 torch.cat((params, ctx_params), dim=1)
546 )
547 scales_hat, means_hat = gaussian_params.chunk(2, 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.
Feature
Support exporting Cheng2020 model to onnx format.
Motivation
To deploy the model on various hardwares.
Additional context
This is my convertion code:
Error occurs: