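Calling `get_viewmat` (decorated with `@torch_compile()`) from `nerfstudio.models.splatfacto` on an NVIDIA Orin NX with PyTorch raises a `TorchRuntimeError` during `torch._dynamo` tracing: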
>>> import torch
>>> from nerfstudio.models import splatfacto
>>> x = torch.randn(1, 3, 4).to("cuda")
>>> splatfacto.get_viewmat(x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
return fn(*args, **kwargs)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 571, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 701, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 408, in _convert_frame_assert
return _compile(
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 625, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 542, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 149, in _fn
return fn(*args, **kwargs)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 507, in transform
tracer.run()
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2192, in run
super().run()
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in run
and self.step()
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 782, in step
getattr(self, inst.opname)(inst)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 471, in wrapper
return inner_fn(self, inst)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 246, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 560, in call_function
return wrap_fx_proxy(tx, proxy, **options)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1353, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1443, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1490, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1451, in get_fake_value
ret_val = wrap_fake_exception(
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 994, in wrap_fake_exception
return fn()
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1452, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1557, in run_node
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1536, in run_node
return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(3, 4)), (slice(None, None, None), slice(None, 3, None), slice(None, 3, None))), **{}):
too many indices for tensor of dimension 2
from user code:
File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/nerfstudio/models/splatfacto.py", line 104, in get_viewmat
R = optimized_camera_to_world[:, :3, :3] # 3 x 3
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
`get_viewmat` with `@torch_compile()` on NVIDIA Orin NX and PyTorch throws an error.

Python and package version: