triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Triton does not support tuple of input tensors for associative_scan as documented #2657

Closed. PeaBrane closed this issue 9 months ago.

PeaBrane commented 12 months ago

I am trying to use associative_scan to compute the EMA (exponential moving average) of a signal. My example is adapted from this issue, except that I am passing multiple tensors (as a tuple) to the associative_scan function.
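For reference, the recurrence being computed is y_t = y_{t-1} * a_t + x_t. Each pair (x, a) represents the affine map y -> a*y + x, and composing two such maps yields another affine map, so the combine is associative, which is exactly what associative_scan requires. A quick NumPy check of this (added here for illustration; combine mirrors the _ema_op below):

```python
import numpy as np

def combine(e1, e2):
    # Composing y -> a1*y + x1 with y -> a2*y + x2 gives
    # y -> (a1*a2)*y + (x1*a2 + x2), same as _ema_op below.
    x1, a1 = e1
    x2, a2 = e2
    return (x1 * a2 + x2, a1 * a2)

# Associativity sanity check on random pairs:
rng = np.random.default_rng(0)
e1, e2, e3 = [(rng.random(), rng.random()) for _ in range(3)]
assert np.allclose(combine(combine(e1, e2), e3), combine(e1, combine(e2, e3)))
```

The full repro script: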

```python
import numpy as np
import torch
import triton
import triton.language as tl

@triton.jit
def _ema_op(element_1, element_2):
    input_1, A_1 = element_1
    input_2, A_2 = element_2

    input_new = input_1 * A_2 + input_2
    A_new = A_1 * A_2

    return (input_new, A_new)

@triton.jit
def kernel(
    values,
    factors,
    Z,
    length: tl.constexpr,
):
    global_id = tl.num_programs(axis=1) * tl.program_id(axis=0) + tl.program_id(axis=1)

    offsets = tl.arange(0, length) + global_id * length

    values_tensor = tl.load(values + offsets)
    factors_tensor = tl.load(factors + offsets)
    # vf = bit_merge(values_tensor, factors_tensor)

    out_values, _ = tl.associative_scan((values_tensor, factors_tensor), 0, combine_fn=_ema_op)

    tl.store(Z + offsets, out_values)

def cumulative_ema(
    values: torch.Tensor,
    factors: torch.Tensor,
) -> torch.Tensor:
    """
    Compute cumulative exponential moving average on last axis of rank-3 inputs.

    Args:
        values: [B, N, T] float32 values
        factors: [B, N, T] float32 decay factors

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    assert len(values.shape) == 3, values.shape
    assert values.shape == factors.shape, (values.shape, factors.shape)

    shape = values.shape
    result = torch.empty_like(values)

    kernel[(shape[0], shape[1])](
        values,
        factors,
        result,
        length=shape[2],
    )
    return result

if __name__ == "__main__":
    shape = (1, 1, 1024 * 8)  # if this is (1, 1, 128) the result is incorrect even with the work-around cumulative_ema_op
    values = np.arange(np.prod(shape)).reshape(shape)
    factors = np.full(shape, 0.9)

    expected = np.zeros(shape, np.float32)
    expected[:, :, 0] = values[:, :, 0]
    for i in range(1, shape[2]):
        expected[:, :, i] = expected[:, :, i - 1] * factors[:, :, i] + values[:, :, i]
    print(expected)

    device = "cuda"
    values = torch.tensor(values, dtype=torch.float32, device=device)
    factors = torch.tensor(factors, dtype=torch.float32, device=device)

    result = cumulative_ema(values, factors)
    result = result.cpu()

    print(result.numpy())
```

According to the documentation, this function should support a tuple of tensors, but I keep getting the following error:

```
Traceback (most recent call last):
  File "/home/ypei/codes/classy/test_scan.py", line 109, in <module>
    result = cumulative_ema(values, factors)
  File "/home/ypei/codes/classy/test_scan.py", line 85, in cumulative_ema
    kernel[(shape[0], shape[1])](
  File "<string>", line 4, in kernel
  File "/home/ypei/opt/anaconda3/envs/triton/lib/python3.10/site-packages/triton/runtime/jit.py", line 428, in launcher_body
    bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
  File "/home/ypei/opt/anaconda3/envs/triton/lib/python3.10/site-packages/triton/compiler/compiler.py", line 513, in compile
    next_module = compile_kernel(module)
  File "/home/ypei/opt/anaconda3/envs/triton/lib/python3.10/site-packages/triton/compiler/compiler.py", line 407, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/home/ypei/opt/anaconda3/envs/triton/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1156, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 15:61:    Z,
    length: tl.constexpr,
):
    global_id = tl.num_programs(axis=1) * tl.program_id(axis=0) + tl.program_id(axis=1)

    offsets = tl.arange(0, length) + global_id * length

    values_tensor = tl.load(values + offsets)
    factors_tensor = tl.load(factors + offsets)
    # vf = bit_merge(values_tensor, factors_tensor)

    z = tl.associative_scan((values_tensor, factors_tensor), 0, combine_fn=_ema_op)
                                                             ^
ValueError('Current implementation only support single tensor input')
```
ThomasRaoux commented 12 months ago

Correct, we currently don't support multiple inputs. We don't have the bandwidth to add support for this at the moment. I can fix the documentation to make it clearer. Contributions are welcome if anybody is interested.

PeaBrane commented 12 months ago

@ThomasRaoux I see, thank you.

Digging a bit deeper into the code, it seems like the associative scan is meant to be applied independently to all the tensors in the input tuple, and the tensors in the tuple do not "interact" with each other.

I am interested in helping add support for multiple inputs to associative scan, but I need to read a bit more of the relevant code to see whether I can contribute anything useful.

ThomasRaoux commented 12 months ago

> @ThomasRaoux I see, thank you.
>
> Digging a bit deeper into the code, it seems like the associative scan is meant to be applied independently to all the tensors in the input tuple, and the tensors in the tuple do not "interact" with each other.

I don't think the code makes any assumptions about what is done within the scan function. Basically, we apply the scan's region multiple times, so what happens within the region doesn't really matter. What needs to be done is to pass multiple inputs and handle multiple accumulators in case the function has multiple results.

That being said, I expect the change to be non-trivial, as there are quite a few places to update.
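To make "multiple inputs and multiple accumulators" concrete, here is a sequential model of what a tuple-input inclusive scan would compute (a sketch for illustration only, not Triton's actual parallel implementation):

```python
def inclusive_scan(combine_fn, elems):
    # elems is a list of tuples; result entry i is the fold of elems[0..i]
    # under combine_fn, which takes two tuples and returns one tuple.
    acc = elems[0]
    out = [acc]
    for e in elems[1:]:
        acc = combine_fn(acc, e)
        out.append(acc)
    return out

# With the EMA combine from the first comment:
# inclusive_scan(lambda e1, e2: (e1[0] * e2[1] + e2[0], e1[1] * e2[1]),
#                [(1.0, 0.9), (2.0, 0.9), (3.0, 0.9)])
# -> [(1.0, 0.9), (2.9, 0.81), (5.61, 0.729)]
```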

> I am interested in helping add support for multiple inputs to associative scan, but I need to read a bit more of the relevant code to see whether I can contribute anything useful.

Of course, no pressure.

lezcano commented 11 months ago

@PeaBrane have you already started working on this? Would you mind if I had a stab at it?

PeaBrane commented 11 months ago

@lezcano No, I have not. Please feel free to take this on! (Out of curiosity, would this be used as a backend for torch.associative_scan?)

lezcano commented 11 months ago

Yep! You can find the PR that adds the initial plumbing within torch.compile at https://github.com/pytorch/pytorch/pull/106581

isuruf commented 9 months ago

https://github.com/openai/triton/pull/2947 added support for a tuple of input tensors. Thanks @lezcano. However, it supports only tensors of the same type. When the tensors have different types, Triton fails with the error: 'tt.scan' op requires the same element type for all operands and results

Script:

```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.triton_heuristics import grid, get_interface_for_device

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

triton_per_fused_cummax_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, persistent_reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from triton.compiler.compiler import AttrsDescriptor

@triton.jit
def _triton_helper_fn0(value0, index0, value1, index1):
    gt = value0 > value1
    return tl.where(gt, value0, value1), tl.where(gt, index0, index1)

@persistent_reduction(
    size_hints=[8, 16],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*i64', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_cummax_0', 'mutated_arg_names': [], 'no_x_dim': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 5
    rnumel = 10
    RBLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp2 = tl.full([1, 1], 0, tl.float32)
    tmp0 = tl.load(in_ptr0 + (r1 + (10*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(rindex.to(tl.int64), [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask & xmask, tmp1, tmp2)
    tmp5, tmp6 = tl.associative_scan((tmp3, tmp4), 1, _triton_helper_fn0)
    tl.store(out_ptr0 + (r1 + (10*x0)), tmp5, rmask & xmask)
    tl.store(out_ptr1 + (r1 + (10*x0)), tmp6, rmask & xmask)
''')

async_compile.wait(globals())
del async_compile

gpu_device = get_interface_for_device('cuda')
stream = gpu_device.get_raw_stream(gpu_device.current_device())

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (5, 10), (10, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((5, 10), (10, 1), torch.float32)
        buf1 = empty_strided_cuda((5, 10), (10, 1), torch.int64)
        triton_per_fused_cummax_0.run(arg0_1, buf0, buf1, 5, 10, grid=grid(5), stream=stream)
        del arg0_1
    return (buf0, buf1, )

from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance

arg0_1 = rand_strided((5, 10), (10, 1), device='cuda:0', dtype=torch.float32)
res = call([arg0_1])
print(res)
print(torch.cummax(arg0_1, 1))
```

Traceback:

```
loc("/tmp/torchinductor_isuruf/6q/c6qiysffvftxrospiguc4xxtxqqiej5t5hho4mcy3lz7bxtvipn6.py":41:54): error: 'tt.scan' op requires the same element type for all operands and results
'tt.scan' op requires the same element type for all operands and results
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/isuruf/git/pytorch2/torch/_inductor/codecache.py", line 2370, in _worker_compile
    kernel.precompile(warm_cache_only_with_cc=cc)
  File "/home/isuruf/git/pytorch2/torch/_inductor/triton_heuristics.py", line 200, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/home/isuruf/git/pytorch2/torch/_inductor/triton_heuristics.py", line 345, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/site-packages/triton/compiler/compiler.py", line 231, in compile
    next_module = compile_ir(module, metadata)
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 265, in <lambda>
    stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 115, in make_ttir
    pm.run(mod)
RuntimeError: PassManager::run failed
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "cumsum.py", line 75, in <module>
    async_compile.wait(globals())
  File "/home/isuruf/git/pytorch2/torch/_inductor/codecache.py", line 2591, in wait
    scope[key] = result.result()
  File "/home/isuruf/git/pytorch2/torch/_inductor/codecache.py", line 2398, in result
    self.future.result()
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/home/isuruf/.conda/envs/pytorch-dev/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
RuntimeError: PassManager::run failed
```
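For reference, the inductor dump above boils down to a much smaller repro: a cummax-style scan over an (fp32, i64) tuple. A minimal sketch of that reduction (illustrative names, written against the post-#2947 tuple API; it is expected to fail at compile time with the same verifier error):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _cummax_combine(v0, i0, v1, i1):
    # Carries (value, index) pairs with different dtypes: fp32 values, i64 indices.
    gt = v0 > v1
    return tl.where(gt, v0, v1), tl.where(gt, i0, i1)

@triton.jit
def cummax_kernel(x_ptr, val_ptr, idx_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    idx = offs.to(tl.int64)
    vals, idxs = tl.associative_scan((x, idx), 0, _cummax_combine)
    tl.store(val_ptr + offs, vals)
    tl.store(idx_ptr + offs, idxs)

x = torch.randn(16, device="cuda")
vals = torch.empty_like(x)
idxs = torch.empty(16, dtype=torch.int64, device="cuda")
# Fails while lowering to TTIR with:
#   'tt.scan' op requires the same element type for all operands and results
cummax_kernel[(1,)](x, vals, idxs, N=16)
```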
ThomasRaoux commented 9 months ago

> https://github.com/openai/triton/pull/2947 added support for a tuple of input tensors. Thanks @lezcano. However, it supports only tensors of the same type. When the tensors have different types, Triton fails with the error: 'tt.scan' op requires the same element type for all operands and results
>
> Script:
>
> Traceback:

This is probably just a missed restriction. It should be easy to make a patch that removes that restriction and adds a simple test to make sure it works.

lezcano commented 9 months ago

Will submit a fix either today or early next week.