Closed by @tfogal 3 months ago
```
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py", line 914, in forward
[rank0]:     return self.mm_projector(x)
```
We could likely print `self.mm_projector.weight.dtype` and `x.dtype` to figure out which dtypes we end up with.
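As a standalone sketch of what that printout would likely reveal (a plain `nn.Linear` standing in for the NeMo projector, not the actual module), a float32 weight receiving a bfloat16 input fails outside an autocast region:

```python
import torch

# Stand-in for self.mm_projector: nn.Linear weights default to float32.
proj = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, dtype=torch.bfloat16)

print(proj.weight.dtype, x.dtype)  # torch.float32 torch.bfloat16

# Outside an autocast region, the mixed dtypes raise a RuntimeError.
try:
    proj(x)
except RuntimeError as e:
    print("dtype mismatch:", e)
```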
I am seeing a different error related to advanced indexing.
```
[rank0]:   File "thunder/core/proxies.py", line 1333, in __getitem__
[rank0]:     return method(self, key)
[rank0]:   File "thunder/core/symbol.py", line 268, in __call__
[rank0]:     result = self.meta(*args, **kwargs)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/torch/__init__.py", line 890, in getitem
[rank0]:     return clang.getitem(a, key)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/clang/__init__.py", line 868, in getitem
[rank0]:     return _advanced_indexing(a, key)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/clang/__init__.py", line 729, in _advanced_indexing
[rank0]:     utils.check(
[rank0]:   File "thunder/core/baseutils.py", line 103, in check
[rank0]:     raise exception_type(s())
[rank0]: RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors, but found a tensor with dtype int64 and 2 dimensions
```
Thunder commit used: `72e033a0e0dfe44d4770dec2399a9058971003ec`
Full Log: neva.log
triage review
I have been able to repro the failure with an independent script. The failure happens due to the interaction of autocast and mixed input dtypes.
```python
import thunder
import torch


def foo(x, w):
    return torch.nn.functional.linear(x, w)


device = torch.device("cuda")
with device:
    # Mixed input types.
    x, w = torch.randn(16, 16, dtype=torch.bfloat16), torch.randn(16, 16)
    # Same input types (works with thunder)
    # x, w = torch.randn(16, 16), torch.randn(16, 16)

    print(x.dtype, w.dtype)
    with torch.autocast("cuda", torch.bfloat16):
        # Eager autocast handles mixed input types.
        eager_out = foo(x, w)

        # `thunder.jit` doesn't handle mixed inputs.
        jfoo = thunder.jit(foo)
        jit_out = jfoo(x, w)

    print(thunder.last_traces(jfoo)[-1])
    torch.testing.assert_close(eager_out, jit_out)
```
> I have been able to repro the failure with an independent script.

Great! Thank you, excellent work :-)
The reason it currently fails: while tracing with `thunder.jit`, the inputs reach thunder's operator implementations without autocast's dtype conversion having been applied first. With mixed input dtypes, we therefore fail at step 1, as these operators don't allow mixed inputs.
(In eager mode, with the context manager active, the dispatcher first applies the conversion before passing the converted inputs to the operators.)
Potential Fix:
@t-vi I would like your opinion and some pointers on this. Thank you!
Great analysis, @kshitij12345!

For 1: We do have autocast handling in `thunder.jit` and `cache_info`.

For 2: To my mind, this is a `thunder.torch` thing more than something specific to `jit_ext`, so I would probably look at trying to handle it in `thunder.torch.torchsymbol`:
https://github.com/Lightning-AI/lightning-thunder/blob/da23a0b0e9ad17568be8566ad839ca0b0e88043b/thunder/torch/__init__.py#L108
WDYT?
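A rough sketch of that idea (purely illustrative; `autocast_inputs` is a hypothetical helper, not thunder's actual `torchsymbol` code): an operator wrapper converts floating-point tensor arguments to the active autocast dtype before calling the underlying op, mirroring what the eager dispatcher does.

```python
import torch

# Assumed active autocast dtype (in thunder this would come from cache_info).
AUTOCAST_DTYPE = torch.bfloat16


def autocast_inputs(fn):
    """Cast floating-point tensor args to the autocast dtype before calling fn,
    mimicking what the eager dispatcher does under torch.autocast."""
    def wrapper(*args, **kwargs):
        def cast(a):
            if isinstance(a, torch.Tensor) and a.is_floating_point():
                return a.to(AUTOCAST_DTYPE)
            return a
        args = tuple(cast(a) for a in args)
        kwargs = {k: cast(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)
    return wrapper


linear = autocast_inputs(torch.nn.functional.linear)
x, w = torch.randn(8, 8, dtype=torch.bfloat16), torch.randn(8, 8)
out = linear(x, w)  # mixed inputs now succeed
print(out.dtype)    # torch.bfloat16
```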
🐛 Bug
Full log of the run that includes the unabbreviated traceback.
To Reproduce
Note you'll need the referenced `./data` directory.

Expected behavior
Environment
cc @crcrpar @tfogal