leademeule opened 5 months ago
Reproduced again on Triton commit bcf367897694c182d52261147370009006ea4936, with the following traceback (TORCHDYNAMO_VERBOSE=1):
Traceback (most recent call last):
  File "$REDACTED_FILE_PATH", line 83, in <module>
    output = torch.compile(cast)(input, dtype_output)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_FILE_PATH", line 66, in cast
    cast_kernel[(triton.cdiv(LENGTH, BLOCK_SIZE),)](
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/cvmfs/ai.mila.quebec/apps/arch/distro/python/3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 944, in call_method
    ).call_function(tx, args, kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 908, in call_function
    "kwargs": meta.as_proxy(),
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 143, in as_proxy
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 143, in <dictcomp>
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 301, in as_proxy
    raise NotImplementedError(str(self))
torch._dynamo.exc.InternalTorchDynamoError: UserDefinedObjectVariable(dtype)

from user code:
  File "$REDACTED_TRITON_PATH/python/triton/runtime/jit.py", line 327, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
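The last frames of the traceback show where the error originates: Dynamo converts the kernel's kwargs dict to graph proxies entry by entry (dicts.py, as_proxy), the entry holding the Triton dtype is a UserDefinedObjectVariable whose as_proxy falls back to raising NotImplementedError (base.py), and convert_frame.py re-raises that as InternalTorchDynamoError. The plain-Python toy below sketches this failure mode under stated assumptions: ConstantVar, OpaqueObjectVar, kwargs_as_proxy and unproxiable_kwargs are hypothetical stand-ins, not Dynamo's real classes; LENGTH and BLOCK_SIZE are borrowed from the user code in the traceback, while OUT_DTYPE and all values are made up for illustration.

```python
# Toy model of the failure in the last traceback frames. These classes are
# hypothetical stand-ins, NOT torch._dynamo's real VariableTracker classes.


class ConstantVar:
    """A traced value the tracer knows how to put into a graph (e.g. an int)."""

    def __init__(self, value):
        self.value = value

    def as_proxy(self):
        return self.value


class OpaqueObjectVar:
    """A traced object the tracer cannot represent, such as a Triton dtype."""

    def __init__(self, type_name):
        self.type_name = type_name

    def as_proxy(self):
        # Mimics the base-class fallback seen in the traceback (base.py, as_proxy).
        raise NotImplementedError(f"UserDefinedObjectVariable({self.type_name})")


def kwargs_as_proxy(kwargs):
    # Like the dict comprehension in dicts.py: every value must be convertible,
    # so a single opaque entry makes the whole kwargs conversion fail.
    return {name: var.as_proxy() for name, var in kwargs.items()}


def unproxiable_kwargs(kwargs):
    """Return the names of kwargs whose values cannot be converted."""
    bad = []
    for name, var in kwargs.items():
        try:
            var.as_proxy()
        except NotImplementedError:
            bad.append(name)
    return bad


# Hypothetical kernel kwargs; OUT_DTYPE stands for the argument carrying the tl.dtype.
kernel_kwargs = {
    "LENGTH": ConstantVar(1024),
    "BLOCK_SIZE": ConstantVar(256),
    "OUT_DTYPE": OpaqueObjectVar("dtype"),
}

error_message = None
try:
    kwargs_as_proxy(kernel_kwargs)
except NotImplementedError as exc:
    error_message = str(exc)

print(error_message)  # UserDefinedObjectVariable(dtype)
print(unproxiable_kwargs(kernel_kwargs))  # ['OUT_DTYPE']
```

In this model, a kwargs dict without the opaque entry converts cleanly; only the dtype-carrying argument blocks the whole conversion.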
Here is the system configuration this error was reproduced on:
os: Ubuntu 22.04.3 LTS
cpu: AMD EPYC 7413
gpu: NVIDIA A100-SXM4-40GB
cuda driver: 535.161.08
cuda version: 12.2
python: 3.10.11
torch: 2.3.1
triton: bcf367897694c182d52261147370009006ea4936
As the title states, calling torch.compile fails when Triton kernel arguments include a triton.language.dtype. This is demonstrated in the code sample below. The snippet is somewhat contrived, but the issue is a real obstacle in practice: without a lot of code duplication, it makes it impossible to compile kernels that need an accumulator whose data type depends on the input. This yields the following traceback: