arogozhnikov / einops

Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)
https://einops.rocks
MIT License

[backend bug] when using einops.repeat, RuntimeError: Unsupported value kind: Tensor #324

NOTGOOOOD closed this issue 5 months ago

NOTGOOOOD commented 5 months ago

Hi! I found a strange bug. My model trains and validates fine using einops.repeat, but exporting it to ONNX raises an error. The only thing I can see is that the failure occurs at backend = BackendSubclass() in _backends.py.

Describe the bug

Traceback (most recent call last):
  File "okulo_a1_demo_tof.py", line 211, in <module>
    net = create_net(args)
  File "okulo_a1_demo_tof.py", line 60, in create_net
    export_onnx(model=net, checkpoint=args.resume, fp_16=args.fp_16, img=img)
  File "/home/xufeng/project/LightConGR/utils/export.py", line 25, in export_onnx
    torch.onnx.export(model,
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xufeng/project/LightConGR/models/gesture_transformer.py", line 54, in forward
    x, logits = self.DTNNetv2(x)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xufeng/project/LightConGR/models/DTNv2.py", line 447, in forward
    sub_x = cls_token(sub_x)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xufeng/project/LightConGR/models/DTNv2.py", line 294, in forward
    cls_token = repeat(self.cls_token, '() n d -> b n d', b=B)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/einops.py", line 641, in repeat
    return reduce(tensor, pattern, reduction="repeat", **axes_lengths)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/einops.py", line 518, in reduce
    backend = get_backend(tensor)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/_backends.py", line 54, in get_backend
    backend = BackendSubclass()
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/_backends.py", line 222, in __init__
    from . import _torch_specific  # noqa
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/_torch_specific.py", line 128, in <module>
    allow_ops_in_compiled_graph()
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/einops/_torch_specific.py", line 107, in allow_ops_in_compiled_graph
    from torch._dynamo import allow_in_graph
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/__init__.py", line 64, in <module>
    torch.manual_seed = disable(torch.manual_seed)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/decorators.py", line 50, in disable
    return DisableContext()(fn)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 410, in __call__
    (filename is None or trace_rules.check(fn))
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 3378, in check
    return check_verbose(obj, is_inlined_call).skipped
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 3361, in check_verbose
    rule = torch._dynamo.trace_rules.lookup_inner(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 3442, in lookup_inner
    rule = get_torch_obj_rule_map().get(obj, None)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 2782, in get_torch_obj_rule_map
    obj = load_object(k)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 2811, in load_object
    val = _load_obj_from_str(x[0])
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/_dynamo/trace_rules.py", line 2795, in _load_obj_from_str
    return getattr(importlib.import_module(module), obj_name)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nested/_internal/nested_tensor.py", line 416, in <module>
    _nt_view_dummy = NestedTensor(
  File "/home/xufeng/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/nested/_internal/nested_tensor.py", line 232, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: Unsupported value kind: Tensor

Your platform

einops 0.8.0
onnx 1.16.0
onnxconverter-common 1.14.0
onnxmltools 1.12.0
onnxruntime-gpu 1.17.1
onnxsim 0.4.36
pytorch 2.3.0
pytorch-cuda 11.8
python 3.8

arogozhnikov commented 5 months ago

hi @NOTGOOOOD

This is where the error comes from:

from torch._dynamo import allow_in_graph 

For some reason this import fails on your machine. einops triggers the import, but the problem seems unrelated to einops code.
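
To check this, the same import can be run on its own in a fresh interpreter, outside torch.onnx.export. The sketch below is only a diagnostic aid: check_import is a hypothetical helper written for this thread (not part of einops or torch), and it only reports whether the import raises.

```python
import importlib


def check_import(module_name, attr_name=None):
    """Try to import a module and, optionally, look up one attribute on it.

    Hypothetical helper for diagnosis only.
    Returns (True, "") on success, or (False, "<ExceptionType>: <message>")
    on failure so the result can be pasted into a bug report.
    """
    try:
        module = importlib.import_module(module_name)
        if attr_name is not None:
            getattr(module, attr_name)
    except Exception as exc:  # report any failure, not only ImportError
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""


if __name__ == "__main__":
    # The import from the traceback, run by itself with no ONNX export or tracing.
    ok, detail = check_import("torch._dynamo", "allow_in_graph")
    if ok:
        print("torch._dynamo import ok")
    else:
        print(f"torch._dynamo import failed: {detail}")
```

If the import succeeds when run like this, the failure is specific to the import happening during export. One thing to try in that case (an assumption, not confirmed in this thread) is running the import once before calling torch.onnx.export, so that it does not first happen while the model is being traced.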