# onnx2torch incorrectly omits one weight layer when converting an ONNX model to PyTorch, causing `IndexError: Dimension out of range` in `node_converters/global_average_pool.py`, line 41, in `<lambda>`
```
Traceback (most recent call last):
  File "/home/suhwan/grad_course/metadl/report.py", line 23, in <module>
    torch_model(input_torch)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__
    raise e
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.0", line 73, in forward
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 46, in forward
    return forward_lambda()
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 41, in <lambda>
    forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```
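## Description

onnx2torch incorrectly omits one weight layer when converting an ONNX model to PyTorch. Because of this mis-conversion, the tensor that reaches the converted `GlobalAveragePool` node has only 3 dimensions; the valid range `[-3, 2]` in the error shows this. The converter nevertheless averages over dimension 3, so `torch.mean(input_tensor, dim=self._x_dims, keepdim=True)` fails with `IndexError: Dimension out of range` in `node_converters/global_average_pool.py`, line 41, in `<lambda>`.

ONNX Runtime runs the same model on the same input without error.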
## Steps to Reproduce

Attachment: `poc.zip` (the reproducing ONNX model).
```python
import onnx
import onnxruntime
import torch
from onnx2torch import convert

poc_onnx_model_path = 'poc.onnx'

# load onnx_model
onnx_model = onnx.load(poc_onnx_model_path)
# check model validity
onnx.checker.check_model(onnx_model)

# input
input_torch = torch.randn(1, 3, 512, 512)
input_ort = {'input': input_torch.numpy()}

# no error in onnx
ort_session = onnxruntime.InferenceSession(poc_onnx_model_path)
output_ort = ort_session.run(None, input_ort)

torch_model = convert(onnx_model)
# error occurs!
torch_model(input_torch)
```
Output:

```
Traceback (most recent call last):
  File "/home/suhwan/grad_course/metadl/report.py", line 23, in <module>
    torch_model(input_torch)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__
    raise e
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.0", line 73, in forward
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 46, in forward
    return forward_lambda()
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 41, in <lambda>
    forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```