triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
11.91k stars 1.41k forks source link

Calling torch.compile fails when Triton kernel arguments include triton.language.dtype #4072

Open leademeule opened 1 month ago

leademeule commented 1 month ago

As the title states, calling torch.compile fails when Triton kernel arguments include a triton.language.dtype, as demonstrated in the code sample below. The snippet itself is somewhat contrived, but the issue is unwieldy in practice: without a lot of code duplication, it makes it impossible to compile kernels that need an accumulator whose data type depends on the input.

import math
import torch
import triton

# Explicit torch -> Triton dtype-name table. Inferring the name from
# ``torch.finfo`` cannot work in general:
#   * finfo rejects integer/bool dtypes;
#   * bfloat16 and float16 share a bit width;
#   * complex dtypes report the width of their real component.
_TORCH_TO_TRITON_DTYPE_NAME = {
    torch.bool: "int1",
    torch.int8: "int8",
    torch.int16: "int16",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.uint8: "uint8",
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
    torch.float64: "fp64",
}


def cast_dtype_torch_to_dtype_triton(
    dtype: torch.dtype,
) -> triton.language.dtype:
    """Return the ``triton.language.dtype`` equivalent to a torch dtype.

    Supported inputs are ``torch.bool``, the signed integer dtypes
    ``int8``-``int64``, ``uint8``, ``float16``, ``bfloat16``, ``float32``
    and ``float64``.

    Args:
        dtype: The torch dtype to convert.

    Returns:
        The matching Triton dtype, built from its Triton type name.

    Raises:
        ValueError: if ``dtype`` has no entry in the table above (e.g.
            complex, float8 or quantized dtypes).
    """
    try:
        dtype_string = _TORCH_TO_TRITON_DTYPE_NAME[dtype]
    except KeyError:
        raise ValueError(
            "cast only supports basic integer or floating point types"
        ) from None
    return triton.language.dtype(dtype_string)

@triton.jit()
def cast_kernel(
    input_pointer,
    output_pointer,
    DTYPE: triton.language.constexpr,
    LENGTH,
    BLOCK_SIZE: triton.language.constexpr,
):
    """Copy ``LENGTH`` elements from ``input_pointer`` to ``output_pointer``,
    converting each loaded element to ``DTYPE``.

    Program ``i`` handles the element offsets
    ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``. Offsets at or past ``LENGTH``
    are masked out of both the load and the store.

    ``LENGTH`` is a regular runtime argument rather than a ``constexpr``.
    Declared ``constexpr``, every distinct element count would trigger a
    separate JIT compilation of this kernel.
    """
    program = triton.language.program_id(0)

    offsets = program * BLOCK_SIZE + triton.language.arange(0, BLOCK_SIZE)
    # Guards the tail block when LENGTH is not a multiple of BLOCK_SIZE.
    mask = offsets < LENGTH

    converted = triton.language.load(
        input_pointer + offsets,
        mask=mask,
    ).to(DTYPE)

    triton.language.store(
        output_pointer + offsets,
        converted,
        mask=mask,
    )

def cast(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return a copy of ``input`` converted to ``dtype`` using ``cast_kernel``.

    Args:
        input: The tensor to convert. Must be contiguous and live on a CUDA
            device.
        dtype: The target torch dtype. It is translated to a Triton dtype by
            ``cast_dtype_torch_to_dtype_triton``.

    Returns:
        A newly allocated tensor with ``input``'s shape, ``dtype`` as its
        dtype, and ``input``'s device.

    Raises:
        ValueError: if ``input`` is not contiguous or not on a CUDA device.
    """
    # ``is_contiguous`` is a method. Testing the bound method object itself
    # is always truthy, so it must be *called* for this check to work.
    if not input.is_contiguous():
        raise ValueError("cast expects its input to be contiguous")
    if input.device.type != "cuda":
        raise ValueError("cast expects its input to be on a CUDA device")

    output = torch.empty(
        input.shape,
        dtype=dtype,
        device=input.device,
    )

    LENGTH = input.numel()
    BLOCK_SIZE = 64
    grid = (triton.cdiv(LENGTH, BLOCK_SIZE),)

    # NOTE: in the torch 2.3.x versions reported in this issue, passing a
    # ``triton.language.dtype`` object as a kernel argument is what makes
    # torch.compile fail with ``UserDefinedObjectVariable(dtype)``.
    cast_kernel[grid](
        input,
        output,
        cast_dtype_torch_to_dtype_triton(dtype),
        LENGTH,
        BLOCK_SIZE,
    )

    return output

device = torch.device("cuda")
dtype_input = torch.float32
dtype_output = torch.float16

input = torch.randn((64, 64), device=device, dtype=dtype_input)
# The uncompiled call, ``cast(input, dtype_output)``, is the eager baseline;
# the torch.compile call below is the one under test.
output = torch.compile(cast)(input, dtype_output)

# Unlike a bare ``assert``, assert_close is not stripped under ``python -O``
# and reports which elements mismatch. The tolerances match
# torch.allclose's defaults.
torch.testing.assert_close(
    output,
    input.to(dtype_output),
    rtol=1e-5,
    atol=1e-8,
)

This yields the following traceback:

Traceback (most recent call last):
  File "$REDACTED_FILE_PATH", line 83, in <module>
    output = torch.compile(cast)(input, dtype_output)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_FILE_PATH", line 66, in cast
    cast_kernel[(triton.cdiv(LENGTH, BLOCK_SIZE),)](
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/cvmfs/ai.mila.quebec/apps/arch/distro/python/3.9/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 942, in call_method
    return TritonKernelVariable(
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 908, in call_function
    "kwargs": meta.as_proxy(),
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/dicts.py", line 143, in as_proxy
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/dicts.py", line 143, in <dictcomp>
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/torch/_dynamo/variables/base.py", line 301, in as_proxy
    raise NotImplementedError(str(self))
torch._dynamo.exc.InternalTorchDynamoError: UserDefinedObjectVariable(dtype)

from user code:
   File "$REDACTED_VIRTUAL_ENVIRONMENT_PATH/lib/python3.9/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Here is the system configuration this error was reproduced on:

os:           Ubuntu 22.04.3 LTS
cpu:          AMD EPYC 7413
gpu:          NVIDIA A100-SXM4-40GB
cuda driver:  535.161.08
cuda version: 12.2
python:       3.9.15
torch:        2.3.0
triton:       2.3.0
leademeule commented 1 month ago

Reproduced again on bcf367897694c182d52261147370009006ea4936, with the following traceback when running with TORCHDYNAMO_VERBOSE=1:

Traceback (most recent call last):
  File "$REDACTED_FILE_PATH", line 83, in <module>
    output = torch.compile(cast)(input, dtype_output)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_FILE_PATH", line 66, in cast
    cast_kernel[(triton.cdiv(LENGTH, BLOCK_SIZE),)](
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/cvmfs/ai.mila.quebec/apps/arch/distro/python/3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 944, in call_method
    ).call_function(tx, args, kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 908, in call_function
    "kwargs": meta.as_proxy(),
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 143, in as_proxy
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 143, in <dictcomp>
    return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "$REDACTED_TRITON_PATH/env/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 301, in as_proxy
    raise NotImplementedError(str(self))
torch._dynamo.exc.InternalTorchDynamoError: UserDefinedObjectVariable(dtype)

from user code:
   File "$REDACTED_TRITON_PATH/python/triton/runtime/jit.py", line 327, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Here is the system configuration this error was reproduced on:

os:           Ubuntu 22.04.3 LTS
cpu:          AMD EPYC 7413
gpu:          NVIDIA A100-SXM4-40GB
cuda driver:  535.161.08
cuda version: 12.2
python:       3.10.11
torch:        2.3.1
triton:       bcf367897694c182d52261147370009006ea4936