Closed · luisfmnunes closed this 9 months ago
Thanks for reporting this!
The decomposition for all aten operators are under https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops
For `aten.roll`, we need a variation of https://github.com/microsoft/onnxscript/blob/b7f215ea130e455bd0cfe999d551389bc0718489/onnxscript/function_libs/torch_lib/ops/core.py#L6897 to handle complex inputs.
For `aten.var.correction`, we need to implement https://github.com/microsoft/onnxscript/blob/b7f215ea130e455bd0cfe999d551389bc0718489/onnxscript/function_libs/torch_lib/ops/core.py#L8261
Please follow this guide https://github.com/microsoft/onnxscript/wiki/TorchLib-function-authoring-guide if you would like to contribute. Thank you!
@justinchuby, thank you for all the provided material.
The implementation of `aten::roll` for complex types seems to have worked fine, because all the tests passed (I basically copied the structure from the real-valued one, applied the roll to each channel [real and imag] individually, and then concatenated).
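The real/imag decomposition described above can be sanity-checked outside the exporter; here is a minimal numpy sketch (the function name is illustrative, not the torch_lib implementation):

```python
import numpy as np

# Illustrative check (not the torch_lib code): rolling the real and imaginary
# parts separately and recombining matches rolling the complex tensor directly.
def roll_complex_via_parts(x: np.ndarray, shift: int, axis: int) -> np.ndarray:
    real = np.roll(x.real, shift, axis=axis)
    imag = np.roll(x.imag, shift, axis=axis)
    return real + 1j * imag

x = (np.arange(6).reshape(2, 3) + 1j * np.arange(6).reshape(2, 3)).astype(np.complex64)
assert np.array_equal(roll_complex_via_parts(x, 1, 1), np.roll(x, 1, axis=1))
```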
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_complex_output_match_opinfo__roll_cpu_complex64 PASSED [ 7%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_bool PASSED [ 14%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_int64 PASSED [ 21%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_int32 PASSED [ 28%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_float16 PASSED [ 35%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_float32 PASSED [ 42%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_complex_output_match_opinfo__roll_cpu_complex64 PASSED [ 50%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_int64 PASSED [ 57%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_bool PASSED [ 64%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_int32 PASSED [ 71%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_float16 PASSED [ 78%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_float32 PASSED [ 85%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestFunctionValidity::test_function_has_op_schema_315_aten_roll PASSED [ 92%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestFunctionValidity::test_function_has_op_schema_316_aten_roll_complex PASSED
I can't say the same about `aten::var`, though. I basically used the same logic as `aten::var_mean` and got the same skipped and xfail tests, like onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__var_unbiased_cpu_float32 SKIPPED (Skip: fixme: Inferred shape and existing shape differ in rank). It seems the new overload of torch.var is always being called, ignoring the old (unbiased) overload. It is strange behaviour, because the following test seems to have called the new overload, passing True from `unbiased` to `dims` (which is implicitly converted to 1), and for some reason `correction` receives a value of 5 (a coincidence with the input shape?). Using numpy with the same input, I obtained the same Actual Output values using a correction of 5.0.
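For reference, `aten.var.correction`'s denominator lines up with numpy's `ddof` parameter, which makes the cross-check above easy to reproduce; a hedged sketch (the values here are illustrative, not the failing test's data):

```python
import numpy as np

# aten.var.correction divides the summed squared deviations by (N - correction),
# which is exactly what numpy's ddof parameter does.
def var_with_correction(x: np.ndarray, correction: int) -> float:
    n = x.size
    mean = x.mean()
    return float(((x - mean) ** 2).sum() / (n - correction))

x = np.array([1.0, 2.0, 4.0, 8.0, 16.0, 32.0], dtype=np.float64)
assert np.isclose(var_with_correction(x, 0), np.var(x, ddof=0))
assert np.isclose(var_with_correction(x, 1), np.var(x, ddof=1))
assert np.isclose(var_with_correction(x, 5), np.var(x, ddof=5))
```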
Now it succeeds when converting to ONNX with the following script:
state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
# model.load_state_dict(state_dict["model"])
model.eval()
logger.info(f"Exporting model to {onnx_model}")
# export(
# model,
# torch.randn(1, 1, 512, 512),
# onnx_model,
# export_params=True,
# do_constant_folding=True,
# opset_version=18,
# input_names=["image"],
# output_names=["center", "grid", "pose_2d", "seg", "img_sup", "seg_sup"],
# dynamic_axes={
# 'image':{0: "batch_size", 2: "height", 3: "width"},
# 'center': {0: "batch_size"},
# "grid": {0: "batch_size"},
# "pose_2d": {0: "batch_size"},
# "seg": {0: "batch_size"},
# "img_sup": {0: "batch_size"},
# "seg_sup": {0: "batch_size"}
# }
# )
# print(torch.export.export(model, (torch.randn(1, 1, 512, 512, dtype=torch.float32),)))
# TorchDynamo exports correctly but there are still unsupported onnxscript ops.
# Waiting Issue Response to solve this.
onnx_prog = torch.onnx.dynamo_export(model, torch.randn(1, 1, 512, 512, dtype=torch.float32))
onnx_prog.save(onnx_model.as_posix(), model_state_dict=state_dict["model"])
onnx_model = onnx.load(onnx_model)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
But I get the following warnings:
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:130: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/fx/passes/readability.py:53: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
new_node = self.module.graph.get_attr(normalized_name)
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer0_1_running_mean target layer0/1/running_mean layer0/1/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer0_1_running_var target layer0/1/running_var layer0/1/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn1_running_mean target layer1/1/0/bn1/running_mean layer1/1/0/bn1/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn1_running_var target layer1/1/0/bn1/running_var layer1/1/0/bn1/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn2_running_mean target layer1/1/0/bn2/running_mean layer1/1/0/bn2/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn2_running_var target layer1/1/0/bn2/running_var layer1/1/0/bn2/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn1_running_mean target layer1/1/1/bn1/running_mean layer1/1/1/bn1/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn1_running_var target layer1/1/1/bn1/running_var layer1/1/1/bn1/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn2_running_mean target layer1/1/1/bn2/running_mean layer1/1/1/bn2/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn2_running_var target layer1/1/1/bn2/running_var layer1/1/1/bn2/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn1_running_mean target layer2/0/bn1/running_mean layer2/0/bn1/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn1_running_var target layer2/0/bn1/running_var layer2/0/bn1/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn2_running_mean target layer2/0/bn2/running_mean layer2/0/bn2/running_mean of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
# Many repeated Warnings with the same signature.
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node decoder_layer2_conv_4_running_var target decoder/layer2/conv/4/running_var decoder/layer2/conv/4/running_var of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/Melo/onnxscript/onnxscript/function_libs/torch_lib/graph_building.py:971: UserWarning: ONNX model is invalid: [ShapeInferenceError] (op_type:Add, node name: Add_56): [TypeInferenceError] Inferred elem type differs from existing elem type: (7) vs (1)
warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
and when I load the model with `onnx.load` and use `onnx.checker.check_model`, it raises the following exception:
Traceback (most recent call last):
File "/home/griaule/Fingerprint-2DPose-Dense-Voting/model2onnx.py", line 88, in <module>
main(parse_args())
File "/home/griaule/Fingerprint-2DPose-Dense-Voting/model2onnx.py", line 75, in main
onnx.checker.check_model(onnx_model)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnx/checker.py", line 148, in check_model
C.check_model(protobuf_string, full_check, skip_opset_compatibility_check)
onnx.onnx_cpp2py_export.checker.ValidationError: preprocess_tv.img_grad.weight_x initializer name is not unique
Yet, when I try to load it as an inference model in onnxruntime, I get the following exception:
>>> import onnxruntime as ort
>>> ort.InferenceSession("out/20231119_175136/best.onnx")
2023-11-27 17:05:35.103468663 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'preprocess_tv.img_grad.weight_x' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103490273 [W:onnxruntime:, graph.cc:1283 Graph] Initializer preprocess_tv.img_grad.weight_x appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103494888 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'preprocess_tv.img_grad.weight_y' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103498240 [W:onnxruntime:, graph.cc:1283 Graph] Initializer preprocess_tv.img_grad.weight_y appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103501922 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.weight_avg' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103505156 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.weight_avg appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103508790 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_grad.weight_x' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103511971 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_grad.weight_x appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103515655 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_grad.weight_y' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103518761 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_grad.weight_y appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103522334 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_gaussian.gkern2d' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103525487 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_gaussian.gkern2d appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
# Many Repeated warnings for different Parameters
2023-11-27 17:05:35.105126349 [W:onnxruntime:, graph.cc:1283 Graph] Initializer pixels_out.weight appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.105131066 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'pixels_out.bias' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.105135000 [W:onnxruntime:, graph.cc:1283 Graph] Initializer pixels_out.bias appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from out/20231119_175136/best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_0) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_15): B has inconsistent type tensor(int64
I also used the `onnx.helper.printable_graph` tool to export the graph structure of the ONNX protobuf. The result is in onnx_graph.txt. Sorry for the long response; I'm kind of lost on how to proceed. Thank you.
I also used ONNX GraphSurgeon to check the types; it seems like all the `_val_` entries are set to `None` or `0` as `dtype` (not sure if this is expected during conversion or what might cause it). Here (onnx_graph_data.txt) is the output I logged from all nodes of the graph.
can you share the onnx model itself? You may zip it and attach it here.
Thanks for doing the experiments! Let's solve this in a few steps.
- For `aten::var`, if you could create a pull request with your implementation and tests, we can look together to make sure it is correct.
- The warnings are from PyTorch dynamo. I would ignore them until we solve the rest and still aren't able to get the model correct.
- The model check error may be related to Fix value_info names in symbolic shape export #1184? Try installing `onnxscript` from the github main branch. If not, this may be a bug in our implementation. I will look into this.
- We can care about the ORT type error later. This may be a type promotion case we are not handling correctly in the converter, but it could also be due to other things that can be solved with (1).
Ok, so I was able to get rid of the duplicated-parameter warnings and the `onnx.checker.check_model` exception by removing the state-dict load from onnx_prog.save() and loading the parameters beforehand with model.load_state_dict. The code below works for model conversion.
state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
model.load_state_dict(state_dict["model"])
model.eval()
logger.info(f"Exporting model to {onnx_model}")
# export(
# model,
# torch.randn(1, 1, 512, 512),
# onnx_model,
# export_params=True,
# do_constant_folding=True,
# opset_version=18,
# input_names=["image"],
# output_names=["center", "grid", "pose_2d", "seg", "img_sup", "seg_sup"],
# dynamic_axes={
# 'image':{0: "batch_size", 2: "height", 3: "width"},
# 'center': {0: "batch_size"},
# "grid": {0: "batch_size"},
# "pose_2d": {0: "batch_size"},
# "seg": {0: "batch_size"},
# "img_sup": {0: "batch_size"},
# "seg_sup": {0: "batch_size"}
# }
# )
# print(torch.export.export(model, (torch.randn(1, 1, 512, 512, dtype=torch.float32),)))
# TorchDynamo exports correctly but there are still unsupported onnxscript ops.
# Waiting Issue Response to solve this.
onnx_prog = torch.onnx.dynamo_export(model, torch.randn(1, 1, 512, 512, dtype=torch.float32))
onnx_prog.save(onnx_model.as_posix())
onnx_model = onnx.load(onnx_model)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
The problem is that the model is still incorrect, showing the warning /home/griaule/Melo/onnxscript/onnxscript/function_libs/torch_lib/graph_building.py:971: UserWarning: ONNX model is invalid: [ShapeInferenceError] (op_type:Add, node name: Add_56): [TypeInferenceError] Inferred elem type differs from existing elem type: (7) vs (1) during torch.onnx.dynamo_export. When I try to use the model in ONNX Runtime, it still raises the inconsistent-type error:
>>> import onnxruntime as ort
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_0) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_15): B has inconsistent type tensor(int64)
Below is the FastCartoonTexture nn.Module class that seems to be crashing:
class ImageGradient(nn.Module):
    def __init__(self):
        super().__init__()
        kernel_x = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]
        kernel_x = torch.FloatTensor(kernel_x).unsqueeze(0).unsqueeze(0)
        kernel_y = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]
        kernel_y = torch.FloatTensor(kernel_y).unsqueeze(0).unsqueeze(0)
        self.weight_x = nn.Parameter(data=kernel_x, requires_grad=False)
        self.weight_y = nn.Parameter(data=kernel_y, requires_grad=False)

    def forward(self, x):
        grad_x = F.conv2d(x, self.weight_x, padding=1)
        grad_y = F.conv2d(x, self.weight_y, padding=1)
        return grad_x, grad_y


class FastCartoonTexture(nn.Module):
    def __init__(self, sigma=2.5, eps=1e-6) -> None:
        super().__init__()
        self.sigma = sigma
        self.eps = eps
        self.cmin = 0.3
        self.cmax = 0.7
        self.lim = 20
        self.img_grad = ImageGradient()

    def lowpass_filtering(self, img, L):
        img_fft = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1)) * L
        img_rec = torch.fft.ifft2(torch.fft.fftshift(img_fft, dim=(-2, -1)))
        img_rec = torch.real(img_rec)
        return img_rec

    def gradient_norm(self, img):
        Gx, Gy = self.img_grad(img)
        return torch.sqrt(Gx ** 2 + Gy ** 2) + self.eps

    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        grid_y, grid_x = torch.meshgrid(torch.linspace(-0.5, 0.5, H), torch.linspace(-0.5, 0.5, W), indexing="ij")
        grid_radius = torch.sqrt(grid_x ** 2 + grid_y ** 2) + self.eps
        L = (1.0 / (1 + (2 * np.pi * grid_radius * self.sigma) ** 4)).type_as(input)[None, None]
        grad_img1 = self.gradient_norm(input)
        grad_img1 = self.lowpass_filtering(grad_img1, L)
        img_low = self.lowpass_filtering(input, L)
        grad_img2 = self.gradient_norm(img_low)
        grad_img2 = self.lowpass_filtering(grad_img2, L)
        diff = grad_img1 - grad_img2
        flag = torch.abs(grad_img1)
        diff = torch.where(flag > 1, diff / flag.clamp_min(self.eps), torch.zeros_like(diff))
        weight = (diff - self.cmin) / (self.cmax - self.cmin)
        weight = torch.clamp(weight, 0, 1)
        cartoon = weight * img_low + (1 - weight) * input
        texture = (input - cartoon + self.lim) * 255 / (2 * self.lim)
        texture = torch.clamp(texture, 0, 255)
        return texture
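The frequency-domain mask L built in forward() can be inspected on its own; here is a numpy sketch of the same Butterworth-style low-pass response (parameter values copied from the class defaults, the function name is illustrative):

```python
import numpy as np

def lowpass_mask(H: int, W: int, sigma: float = 2.5, eps: float = 1e-6) -> np.ndarray:
    # Centered frequency radius on [-0.5, 0.5] x [-0.5, 0.5], as in forward().
    grid_y, grid_x = np.meshgrid(
        np.linspace(-0.5, 0.5, H), np.linspace(-0.5, 0.5, W), indexing="ij"
    )
    radius = np.sqrt(grid_x**2 + grid_y**2) + eps
    return 1.0 / (1.0 + (2 * np.pi * radius * sigma) ** 4)

L = lowpass_mask(9, 9)
assert L.shape == (9, 9)
assert L[4, 4] > 0.99   # passes low frequencies near the center
assert L[0, 0] < 0.01   # attenuates high frequencies at the corners
```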
Could you share your pytorch version? I would make sure it is the latest torch-nightly build.
This might be the problem. Using the package dunder I got the following:
>>> torch.__version__
'2.1.1+cu118'
I'll try it out tomorrow using the latest torch-nightly build.
I installed the torch-nightly build and still get the same warning regarding ShapeInferenceError and TypeInferenceError. The only difference now is that the indexing of the nodes changed:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_149) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64
I changed the backbone from a manually built resnet to the backbone from timm and retrained the model. The new model architecture seems to have a much straighter graph (checked in Netron), but the same problem still occurs in Div_148. Here is the new model.
Python 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import onnxruntime as ort
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_7) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64)
>>>
I also noticed the following UserWarning from torch.onnx.dynamo_export that might be a clue that something goes wrong during graph building.
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/fx/passes/readability.py:53: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
new_node = self.module.graph.get_attr(normalized_name)
Looks like we need a CastLike here
cc @fatcat-z
@luisfmnunes do you think the model layer is using linspace there?
> @luisfmnunes do you think the model layer is using linspace there?
Most likely. The nn.Module that seems to be the problem is the FastCartoonTexture class, and if we look at its forward function, the meshgrid is built from linspace tensors.
class FastCartoonTexture(nn.Module):
    def __init__(self, sigma=2.5, eps=1e-6) -> None:
        super().__init__()
        self.sigma = sigma
        self.eps = eps
        self.cmin = 0.3
        self.cmax = 0.7
        self.lim = 20
        self.img_grad = ImageGradient()

    def lowpass_filtering(self, img, L):
        img_fft = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1)) * L
        img_rec = torch.fft.ifft2(torch.fft.fftshift(img_fft, dim=(-2, -1)))
        img_rec = torch.real(img_rec)
        return img_rec

    def gradient_norm(self, img):
        Gx, Gy = self.img_grad(img)
        return torch.sqrt(Gx ** 2 + Gy ** 2) + self.eps

    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        grid_y, grid_x = torch.meshgrid(torch.linspace(-0.5, 0.5, H), torch.linspace(-0.5, 0.5, W), indexing="ij")
        grid_radius = torch.sqrt(grid_x ** 2 + grid_y ** 2).type_as(input) + self.eps
        L = (1.0 / (1 + (2 * np.pi * grid_radius * self.sigma) ** 4)).type_as(input)[None, None]
        grad_img1 = self.gradient_norm(input)
        grad_img1 = self.lowpass_filtering(grad_img1, L)
        img_low = self.lowpass_filtering(input, L)
        grad_img2 = self.gradient_norm(img_low)
        grad_img2 = self.lowpass_filtering(grad_img2, L)
        diff = grad_img1 - grad_img2
        flag = torch.abs(grad_img1)
        diff = torch.where(flag > 1, diff / flag.clamp_min(self.eps), torch.zeros_like(diff))
        weight = (diff - self.cmin) / (self.cmax - self.cmin)
        weight = torch.clamp(weight, 0, 1)
        cartoon = weight * img_low + (1 - weight) * input
        texture = (input - cartoon + self.lim) * 255 / (2 * self.lim)
        texture = torch.clamp(texture, 0, 255)
        return texture
Great - we will need to fix linspace. If you like please feel free to add a cast like in the lines above and see if it gives you the correct model. I will create a fix this week.
> Great - we will need to fix linspace. If you like please feel free to add a cast like in the lines above and see if it gives you the correct model. I will create a fix this week.
Cool, the CastLike solves the problem of the Div, but right after it the Mul on the return also has an inconsistent type (probably due to a type mismatch between start and the range tensors). If I cast the range tensor the same way as start, the following error happens later in the graph:
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (_inline_models_model_zoo_DenseHoughVoter_voter_1aten_rsub_186) Op (aten_rsub) [ShapeInferenceError] (op_type:Sub, node name: n3): B has inconsistent type tensor(int64)
Edit: The problem is still linspace. When I passed the dtype keyword to the linspace calls, I got a correct model.
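For reference, the dtype workaround amounts to passing an explicit dtype to torch.linspace so both meshgrid axes match the input's dtype. A minimal sketch, where H and W stand in for the input's spatial size:

```python
import torch

H, W = 512, 512  # example spatial dimensions

# Forcing an explicit dtype keeps both axes consistent with the model input,
# so the exported graph does not mix float and double tensors.
grid_y, grid_x = torch.meshgrid(
    torch.linspace(-0.5, 0.5, H, dtype=torch.float32),
    torch.linspace(-0.5, 0.5, W, dtype=torch.float32),
    indexing="ij",
)
```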
Thank you very much for all your attention and help @justinchuby. I'll check what is failing in my PR and try to contribute the aten::roll for complex inputs and aten::var (dim and correction).
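The per-channel trick described earlier relies on roll acting independently on each element, so rolling the real and imaginary parts separately and recombining them matches rolling the complex tensor directly. A quick numpy sanity check of that property:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5)) + 1j * rng.standard_normal((4, 5))

# Roll the complex tensor directly...
rolled = np.roll(x, shift=2, axis=0)
# ...and roll the real/imag channels independently, then recombine.
manual = np.roll(x.real, 2, axis=0) + 1j * np.roll(x.imag, 2, axis=0)

assert np.array_equal(rolled, manual)
```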
FYI if you find the functions with too many if branches are bothering you because of performance, you may consider https://github.com/microsoft/onnxscript/pull/1178
Well, despite the correctness of the model (meaning it has a valid graph), it seems like the parameters of the model were not loaded, whether I load the weights beforehand using nn.Module.load_state_dict
or pass them through the torch.onnx.dynamo_export
kwargs. The results seem pretty much random and diverge a lot from the torch model output.
Could you share a script you use for export and comparison?
Sure.
The export script:
import torch
import onnx
import yaml
from torch.onnx import export
from pathlib import Path
from loguru import logger
from argparse import ArgumentParser
from models.model_zoo import GRIDNET4, GRIDTIMMNET4
def main(args):
    torch.set_default_device("cpu")
    root = args.model_dir
    config_file = root / "configs.yaml"
    checkpoint = root / args.model_version
    logger.info(f"Exporting model {root} to ONNX")
    onnx_model = checkpoint.parent / checkpoint.with_suffix(".onnx").name
    logger.info(f"Loading Model config file {config_file}")
    with open(config_file, "r") as f:
        config = yaml.load(f.read(), yaml.Loader)
    if config["exp_name"] == "gridnet4":
        model = GRIDNET4(
            num_pose_2d=config["num_pose_2d"],
            num_layers=config["num_layers"],
            img_ppi=config["img_ppi"],
            middle_shape=config["middle_shape"],
            with_tv=config["with_tv"],
            with_enh=config["with_enh"],
            bin_type=config["bin_type"],
            activate=config["activate"],
            pretrained=False,
        )
    else:
        model = GRIDTIMMNET4(**config)
    logger.info(f"Loading model checkpoint {checkpoint}")
    state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
    # model.load_state_dict(state_dict["model"])  # Load model state_dict previously
    model.eval()
    logger.info(f"Exporting model to {onnx_model}")
    # TorchDynamo exports correctly but there are still unsupported onnxscript ops.
    # Waiting Issue Response to solve this.
    # onnx_export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
    onnx_prog = torch.onnx.dynamo_export(
        model, torch.randn(1, 1, 512, 512, dtype=torch.float32),
        # export_options=onnx_export_options
    )
    # onnx_prog.save(onnx_model.as_posix())  # version without model dict load
    onnx_prog.save(onnx_model.as_posix(), model_state_dict=state_dict["model"])
    onnx_model = onnx.load(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))
    onnx.checker.check_model(onnx_model)

def parse_args():
    parser = ArgumentParser()
    parser.add_argument("model_dir", help="Parent directory from model", type=Path)
    parser.add_argument("--model_version", help="Version of exported checkpoint", type=str, default="best.pth")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())
The Comparison Script (Edit: added a function to compare the parameters)
import cv2
import yaml
import torch
import onnx
import numpy as np
import onnxruntime as ort
from pathlib import Path
from loguru import logger
from argparse import ArgumentParser
from deploy_gridnet import process_img
from models.model_zoo import GRIDNET4, GRIDTIMMNET4
def get_onnx_tensor_dict(onnx_load):
    return {t.name: onnx.numpy_helper.to_array(t) for t in onnx_load.graph.initializer}

def compare_onnx_graph_and_state_dict(onnx_dict, state_dict):
    torch_keys = [k for k in state_dict.keys() if k not in onnx_dict]
    onnx_keys = [k for k in onnx_dict.keys() if k not in state_dict]
    for k, v in onnx_dict.items():
        if k in onnx_keys:
            continue
        is_close = np.isclose(
            v,
            state_dict[k].numpy()
        )
        if not is_close.all():
            logger.warning(
                f"Parameter {k} is diverging. {is_close}"
            )
    logger.warning(f"ONNX Keys not in PyTorch {onnx_keys}")
    logger.warning(f"PyTorch Keys not in ONNX {torch_keys}")

def main(args):
    config_file = args.root / "configs.yaml"
    logger.info(f"Reading config file {config_file}")
    with open(config_file, "r") as f:
        config = yaml.load(f, yaml.Loader)
    if config.get("architecture", None):
        model = GRIDTIMMNET4(**config)
    else:
        model = GRIDNET4(**config)
    checkpoint_file = args.root / f"{args.name}.pth"
    onnx_file = checkpoint_file.with_suffix(".onnx")
    logger.info(f"Loading Checkpoint {checkpoint_file}")
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model.eval()
    logger.info(f"Loading ONNX Model {onnx_file}")
    onnx_load = onnx.load(onnx_file)
    onnx_model = ort.InferenceSession(onnx_file)
    logger.info("Comparing ONNX Graph and PyTorch State Dict")
    compare_onnx_graph_and_state_dict(
        get_onnx_tensor_dict(onnx_load),
        checkpoint["model"],
    )
    image_file = Path(__file__).parent / "image/1_1.tif"
    logger.info(f"Loading image {image_file}")
    # im = cv2.imread(
    #     image_file.as_posix(),
    #     cv2.IMREAD_GRAYSCALE,
    # ).astype(np.float32)
    im = np.random.randn(512, 512).astype(np.float32)
    # im, _, _ = process_img(im, 500, None)
    print(im.shape)
    im_tensor = torch.from_numpy(im)[None][None]
    print(im_tensor.shape)
    with torch.no_grad():
        torch_out = model(im_tensor)
    onnx_out = onnx_model.run(["cat_4"], {"l_input_": im[None, None]})
    print(torch_out["pose_2d"])
    print(onnx_out)

def parse_args():
    parser = ArgumentParser()
    parser.add_argument("root", type=Path, help="Path to output models")
    parser.add_argument(
        "--name", "-n", type=str, help="Name of model file", default="best"
    )
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())
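As a side note, a tolerance-based check is usually more robust than eyeballing the printed outputs. A minimal sketch, with hypothetical arrays standing in for torch_out["pose_2d"].numpy() and onnx_out[0]:

```python
import numpy as np

# Hypothetical outputs; in practice these would be the PyTorch and ORT results.
torch_out = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
onnx_out = np.array([[0.1000001, 0.2, 0.2999999]], dtype=np.float32)

# Raises AssertionError with a per-element report if outputs diverge.
np.testing.assert_allclose(onnx_out, torch_out, rtol=1e-3, atol=1e-5)
```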
Edit: Checking some parameters manually, they seem to have been converted fine. I'll make a script to verify that everything is in order.
Edit2: Seems like there are no deviations in the parameters, leading me to believe that the graph is incorrect. I will have to verify it further.
Edit3: I always get the same result from the ONNX model regardless of the input. I also get a broadcast error if I set dynamic shapes and feed a shape different from the one used in the export function ([1, 1, 512, 512]
). I'll upload the model if you wish to check it out. Model
(384, 384)
torch.Size([1, 1, 384, 384])
2023-11-29 16:50:38.636816288 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512
Traceback (most recent call last):
File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 104, in <module>
main(parse_args())
File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 78, in main
onnx_out = onnx_model.run(
^^^^^^^^^^^^^^^
File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
return self._sess.run(output_names, input_feed, run_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512
Edit4: Comparing the ONNX Graph and the State Dict, the following result is seen:
Edit5: Seems like I screwed up some reference in aten_linspace when using CastLike. When I removed the changes from aten_linspace and forced torch.float32 in the linspace calls, I got a result very similar to PyTorch's.
So we are good?
Yeah, we're good. There might be a problem in complex operations, as mentioned in Edit4, but it is all good; if I eventually need it to handle any shape without preprocessing, I'll open a new issue. I'm closing this now.
Thank you once again @justinchuby.
Hello,
First of all, sorry for this post; I'm still kind of lost on how ONNX opset 18 works and how TorchDynamo exports the model to an ONNX protobuf. Well, I trained my model and now I'm trying to export it to ONNX. Using torch.export.export I can generate an ExportedProgram with the following signature:
Unfortunately, this architecture results in the error:
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.roll.default', 'aten.var.correction']}.
Is there any guideline on how to solve this problem and implement support for the aforementioned operations? Thank you, and sorry for the long post.