✅ Obtain model graph with `torch.export.export`
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy
Error message:

```
Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 453, in eval_function
    return function.function(**named_inputs, **named_attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/onnxscript/onnxscript/function_libs/torch_lib/ops/nn.py", line 1166, in aten_mse_loss
    result = op.Mul(self - target, self - target)
                    ~~~~~^~~~~~~~
TypeError: unsupported operand type(s) for -: 'Input' and 'Input'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 398, in _handle_call_function_node_with_lowering
    outputs = onnx_function(*onnx_args, **onnx_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/onnxscript/onnxscript/values.py", line 528, in __call__
    return evaluator.default().eval_function(self, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 462, in eval_function
    raise RuntimeError(
RuntimeError: Error calling function 'aten_mse_loss' with args (Input('arg0_1', type=Tensor(FLOAT), shape=[2,3,5], producer=None, index=None), Input('arg1_1', type=Tensor(FLOAT), shape=[2,3,5], producer=None, index=None), 0) and kwargs {}.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 490, in _add_nodes
    _handle_call_function_node_with_lowering(
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 400, in _handle_call_function_node_with_lowering
    raise RuntimeError(
RuntimeError: Error when calling function 'OnnxFunction(<function aten_mse_loss at 0x13f9ac180>)' with args '[Input('arg0_1', type=Tensor(FLOAT), shape=[2,3,5], producer=None, index=None), Input('arg1_1', type=Tensor(FLOAT), shape=[2,3,5], producer=None, index=None), 0]' and kwargs '{}'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_patch.py", line 222, in _torch_onnx_export
    ir_model = torch_onnx.exported_program_to_ir(program)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 626, in exported_program_to_ir
    values = _add_nodes(exported_program, model, lower=lower, registry=registry)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 501, in _add_nodes
    raise RuntimeError(
RuntimeError: Error when translating node %mse_loss : [num_users=1] = call_function[target=torch.ops.aten.mse_loss.default](args = (%arg0_1, %arg1_1, 0), kwargs = {}). See the stack trace for more information.
```
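The innermost error suggests that `aten_mse_loss` relies on the overloaded `-` operator (`self - target`), while the `Input` values that torch-onnx's graph-building evaluator passes in do not implement `__sub__`. A toy stand-in class (hypothetical, not the real `Input`) demonstrates the mechanism:

```python
class Input:
    """Toy stand-in for torch-onnx's Input value (hypothetical).

    It deliberately defines no __sub__, so subtracting two instances
    fails exactly as in the traceback above.
    """

    def __init__(self, name):
        self.name = name

a, b = Input("arg0_1"), Input("arg1_1")
try:
    a - b
except TypeError as e:
    print(e)  # unsupported operand type(s) for -: 'Input' and 'Input'
```

A plausible fix on the function-library side would be an explicit `op.Sub(self, target)` instead of the `-` overload, though that is a guess about the intended resolution, not a confirmed patch.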
Exported program:

## PyTorch ONNX Conversion Analysis
### Model Information

The model has 0 parameters and 0 buffers (non-trainable parameters).

Number of parameters per dtype:

Number of buffers per dtype:
Inputs:

- `arg0_1`: `TensorMetadata(shape=torch.Size([2, 3, 5]), dtype=torch.float32, requires_grad=False, stride=(15, 5, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`
- `arg1_1`: `TensorMetadata(shape=torch.Size([2, 3, 5]), dtype=torch.float32, requires_grad=False, stride=(15, 5, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`

Outputs:

- `mse_loss`: `TensorMetadata(shape=torch.Size([2, 3, 5]), dtype=torch.float32, requires_grad=False, stride=(15, 5, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`
- `mse_loss_1`: `TensorMetadata(shape=torch.Size([]), dtype=torch.float32, requires_grad=False, stride=(), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`
- `mse_loss_2`: `TensorMetadata(shape=torch.Size([]), dtype=torch.float32, requires_grad=False, stride=(), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`
The FX graph has 6 nodes in total. Number of FX nodes per op:

- `placeholder`: 2
- `call_function`: 3
- `output`: 1

Of the `call_function` nodes, the counts of operators used are:

- `aten.mse_loss.default`: 3

### ONNX Conversion Information

All operators in the model have registered ONNX decompositions.