triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Support for tl shift operator on tensor #3954

Open thumbe3 opened 2 months ago

thumbe3 commented 2 months ago

Problem Statement: Stencil sum computation on the GPU, i.e. the output at each index should be the sum of the input values at indices [index - RADIUS, index + RADIUS], where RADIUS is a constant known at kernel compile time. The input has the same size as the output. For border elements of the output, input values are treated as zero for any index in [cur_index - RADIUS, cur_index + RADIUS] that is out of bounds.

The solution shown below loads BLOCK_SIZE elements 2 * RADIUS + 1 times to build 2 * RADIUS + 1 tensors, adds them all together, and then stores the result to the output in global memory. However, this is expensive in terms of global-memory load operations.

import triton
import triton.language as tl


@triton.jit
def stencil_kernel(
    inputs: tl.tensor,
    outputs: tl.tensor,
    shape: tl.int32,
    BLOCK_SIZE: tl.constexpr,
    RADIUS: tl.constexpr):

    pid = tl.program_id(0)
    # Starting offsets for this program's block of the output
    base_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    output_tensor = tl.zeros([BLOCK_SIZE], tl.float32)

    # One load of BLOCK_SIZE elements per stencil offset
    for offset in tl.static_range(-RADIUS, RADIUS + 1):
        offs = base_offs + offset
        input_ptrs = inputs + offs
        # Elementwise masks must be combined with `&`, not `and`
        input_tensor = tl.load(input_ptrs,
            mask=(offs >= 0) & (offs < shape),
            other=0.0)
        output_tensor += input_tensor

    tl.store(outputs + base_offs, output_tensor, mask=base_offs < shape)
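
For context, a minimal host-side launch of this kernel might look like the sketch below; the sizes, RADIUS, and BLOCK_SIZE values are made up for illustration and are not the ones used in my benchmarks.

import torch
import triton

# Illustrative values only
N = 1 << 20
RADIUS = 8
BLOCK_SIZE = 1024

x = torch.randn(N, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)

# One program instance per BLOCK_SIZE chunk of the output
grid = (triton.cdiv(N, BLOCK_SIZE),)
stencil_kernel[grid](x, y, N, BLOCK_SIZE=BLOCK_SIZE, RADIUS=RADIUS)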

I found that as RADIUS increases, we don't get performance as good as a hand-tuned CUDA kernel similar to the one shown here: https://github.com/olcf/cuda-training-series/blob/master/exercises/hw2/stencil_1d_solution.cu.

Is there a better way to do this in Triton?

What might help is support for a shift operator on a tl.tensor. In that case, you could load BLOCK_SIZE + 2 * RADIUS elements into a tensor once, then use something like tl.shift() on that tensor to get the 2 * RADIUS other tensors, and add all of them. Maybe it could be lowered similarly to what CUDA does with __shfl_down_sync?
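
To make the request concrete, here is a rough sketch of how such a kernel might look. Note that tl.shift does not exist in Triton today; the WINDOW parameter (a power of two at least BLOCK_SIZE + 2 * RADIUS) is introduced here only because tl.arange requires a power-of-two size, and the assumed semantics are that tl.shift(x, k) shifts a 1D tensor left by k positions and zero-fills the vacated slots.

# Hypothetical sketch only: tl.shift is the proposed (non-existent) operator
@triton.jit
def stencil_kernel_shift(
    inputs: tl.tensor,
    outputs: tl.tensor,
    shape: tl.int32,
    BLOCK_SIZE: tl.constexpr,
    WINDOW: tl.constexpr,  # assumed power of two >= BLOCK_SIZE + 2 * RADIUS
    RADIUS: tl.constexpr):

    pid = tl.program_id(0)
    idx = tl.arange(0, WINDOW)
    # Single load of the block plus a halo of RADIUS on each side
    in_offs = pid * BLOCK_SIZE - RADIUS + idx
    window = tl.load(inputs + in_offs,
        mask=(in_offs >= 0) & (in_offs < shape)
            & (idx < BLOCK_SIZE + 2 * RADIUS),
        other=0.0)

    acc = tl.zeros([WINDOW], tl.float32)
    for k in tl.static_range(0, 2 * RADIUS + 1):
        # Assumed semantics: tl.shift(window, k)[i] == window[i + k], zero-padded
        acc += tl.shift(window, k)

    # acc[i] now holds the stencil sum for output element pid * BLOCK_SIZE + i
    out_offs = pid * BLOCK_SIZE + idx
    tl.store(outputs + out_offs, acc,
        mask=(idx < BLOCK_SIZE) & (out_offs < shape))

With this shape, each program would issue one load of BLOCK_SIZE + 2 * RADIUS elements instead of 2 * RADIUS + 1 loads of BLOCK_SIZE elements.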

yunjiangster commented 1 month ago

Maybe you can achieve something similar by loading a 2d tensor (via unsqueeze and broadcasting) with overlapping elements. Here is an example:

import triton
import triton.language as tl
import torch

@triton.jit
def test_stencil(x_ptr, o_ptr):
    pid = tl.program_id(axis=0)
    rng = tl.arange(0, 4)
    # 4x4 grid of overlapping offsets: row i reads x[i], x[i+1], x[i+2], x[i+3]
    x = tl.load(x_ptr + rng[:, None] + rng[None, :])
    # Each output element is the sum over one row of the 4x4 tile
    tl.store(o_ptr + rng, tl.sum(x, axis=1))

x = torch.arange(8).cuda()
y = torch.zeros_like(x)
test_stencil[(1,)](x, y)
x, y

The output looks like this:

(tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0'),
 tensor([ 6, 10, 14, 18,  0,  0,  0,  0], device='cuda:0'))
thumbe3 commented 1 month ago

Hi @yunjiangster, this is definitely a good idea for making the code concise. However, it has the same problem of multiple loads. Here is the implementation I did, borrowing your idea. NUM_OFFSETS below is the next power of 2 greater than or equal to (2 * RADIUS + 1).

@triton.jit
def stencil_kernel_v2(
    inputs: tl.tensor,
    outputs: tl.tensor,
    shape: tl.int32,
    BLOCK_SIZE: tl.constexpr,
    NUM_OFFSETS: tl.constexpr,
    RADIUS: tl.constexpr):

    pid = tl.program_id(0)
    # Starting offsets for this program's block of the output
    base_offs = tl.arange(0, BLOCK_SIZE)
    # Stencil offsets; entries past +RADIUS are masked out below
    inc_offs = tl.arange(0, NUM_OFFSETS) - RADIUS
    output_offsets = pid * BLOCK_SIZE + base_offs
    # (BLOCK_SIZE, NUM_OFFSETS) grid of overlapping input offsets
    input_offsets = pid * BLOCK_SIZE + inc_offs[None, :] + base_offs[:, None]

    # Elementwise masks must be combined with `&`, not `and`
    input_tensor = tl.load(inputs + input_offsets,
        mask=(input_offsets >= 0) & (input_offsets < shape)
            & (inc_offs[None, :] <= RADIUS),
        other=0.0)

    tl.store(outputs + output_offsets,
        tl.sum(input_tensor, axis=1),
        mask=output_offsets < shape)

Even this kernel has performance similar to the previous Triton kernel and starts to lag behind the hand-tuned CUDA implementation at high values of RADIUS.
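
For reference, NUM_OFFSETS can be computed on the host with triton.next_power_of_2; the sizes and values below are made up for illustration.

import torch
import triton

# Illustrative values only
N = 1 << 20
RADIUS = 8
BLOCK_SIZE = 1024
NUM_OFFSETS = triton.next_power_of_2(2 * RADIUS + 1)

x = torch.randn(N, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)

grid = (triton.cdiv(N, BLOCK_SIZE),)
stencil_kernel_v2[grid](x, y, N,
    BLOCK_SIZE=BLOCK_SIZE, NUM_OFFSETS=NUM_OFFSETS, RADIUS=RADIUS)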

yunjiangster commented 1 month ago

@thumbe3 ah, that makes sense. Loading the same element repeatedly will probably still cost multiple loads. I am curious how Triton handles convolution kernels then? They do something similar.