nerfstudio-project / nerfstudio

A collaboration friendly studio for NeRFs
https://docs.nerf.studio
Apache License 2.0

Error when using get_viewmat with @torch_compile() on Nvidia Orin NX #3245

Open ymtoo opened 1 week ago

ymtoo commented 1 week ago

Calling get_viewmat, which is decorated with @torch_compile(), throws an error on an Nvidia Orin NX running PyTorch.

>>> import torch
>>> from nerfstudio.models import splatfacto
>>> x = torch.randn(1, 3, 4).to("cuda")
>>> splatfacto.get_viewmat(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 571, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 701, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 408, in _convert_frame_assert
    return _compile(
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 625, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 542, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 149, in _fn
    return fn(*args, **kwargs)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 507, in transform
    tracer.run()
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2192, in run
    super().run()
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in run
    and self.step()
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 782, in step
    getattr(self, inst.opname)(inst)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 471, in wrapper
    return inner_fn(self, inst)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 246, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 560, in call_function
    return wrap_fx_proxy(tx, proxy, **options)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1353, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1443, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1490, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1451, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 994, in wrap_fake_exception
    return fn()
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1452, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1557, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1536, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(3, 4)), (slice(None, None, None), slice(None, 3, None), slice(None, 3, None))), **{}):
too many indices for tensor of dimension 2

from user code:
   File "/home/arl/miniconda3/envs/ChriisLink/lib/python3.10/site-packages/nerfstudio/models/splatfacto.py", line 104, in get_viewmat
    R = optimized_camera_to_world[:, :3, :3]  # 3 x 3

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
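
For context (my own minimal repro, not from the original report): the FakeTensor in the traceback has size (3, 4) even though the input is (1, 3, 4), so Dynamo appears to trace a 2-D tensor. The three-index slice in get_viewmat then fails exactly as plain eager indexing on a 2-D tensor does:

```python
import torch

t = torch.randn(3, 4)  # 2-D, matching the FakeTensor size in the traceback
try:
    t[:, :3, :3]  # same indexing pattern as R = optimized_camera_to_world[:, :3, :3]
except IndexError as e:
    print(e)  # "too many indices for tensor of dimension 2"
```

This suggests the problem is the shape Dynamo fakes on this platform, not the slice itself, which works fine on a (1, 3, 4) input in eager mode.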

Python and package versions:

Python 3.10.14
nerfstudio 1.1.3
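
As a temporary workaround, the fallback suggested at the end of the traceback can be applied before running the model (a sketch; this silently drops compilation for any frame that fails to compile, so it masks rather than fixes the issue):

```python
import torch._dynamo

# Suppress torch.compile errors and fall back to eager execution,
# per the hint in the traceback above.
torch._dynamo.config.suppress_errors = True
```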
ymtoo commented 1 day ago

Xref: https://forums.developer.nvidia.com/t/error-when-using-torch-compile-on-jetson-orin-nx/297421