thumbe3 opened this issue 2 months ago.
Maybe you can achieve something similar by loading a 2D tensor of overlapping elements (built via unsqueezing and broadcasting). Here is an example:
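```python
import torch
import triton
import triton.language as tl


@triton.jit
def test_stencil(x_ptr, o_ptr):
    # Single program; pid is not needed in this toy example.
    rng = tl.arange(0, 4)
    # rng[:, None] + rng[None, :] is a 4x4 grid of offsets i + j, so row i
    # loads the overlapping window x[i], ..., x[i + 3].
    x = tl.load(x_ptr + rng[:, None] + rng[None, :])
    # Sum each window (row) and write the 4 results to the first 4 outputs.
    tl.store(o_ptr + rng, tl.sum(x, axis=1))


x = torch.arange(8).cuda()
y = torch.zeros_like(x)
test_stencil[(1,)](x, y)
x, y  # evaluated in a notebook cell
```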
The output looks like this:

```
(tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0'),
 tensor([ 6, 10, 14, 18,  0,  0,  0,  0], device='cuda:0'))
```
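To see why this works, here is a plain-Python emulation (no Triton or GPU needed) of the offset grid that `rng[:, None] + rng[None, :]` produces for `tl.arange(0, 4)`. Entry (i, j) is i + j, so row i addresses the window x[i], ..., x[i + 3], and summing along axis 1 gives one windowed sum per row. It also counts how many element loads the grid requests versus how many distinct elements it actually touches.

```python
# Plain-Python emulation of the offset grid built in test_stencil.
N = 4
rng = list(range(N))                           # tl.arange(0, 4)
offsets = [[i + j for j in rng] for i in rng]  # rng[:, None] + rng[None, :]

x = list(range(8))                             # same data as torch.arange(8)

# "Load" x at every offset; neighbouring rows overlap by N - 1 elements.
loaded = [[x[off] for off in row] for row in offsets]

# tl.sum(x, axis=1): one windowed sum per row.
row_sums = [sum(row) for row in loaded]

requested = sum(len(row) for row in offsets)             # element loads issued
distinct = len({off for row in offsets for off in row})  # elements actually needed

for row in offsets:
    print(row)
print(row_sums, requested, distinct)
```

The row sums match the Triton output above, while the grid requests 16 element loads for only 7 distinct elements.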
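Hi @yunjiangster, this is definitely a good way to make the code concise. However, it has the same problem of repeated loads: each input element is still loaded by up to 2 * RADIUS + 1 rows. Here is the implementation I wrote by borrowing your idea. NUM_OFFSETS below is the next power of 2 of 2 * RADIUS + 1, since `tl.arange` requires a power-of-2 range; the extra padding lanes are masked out.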
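```python
import triton
import triton.language as tl


@triton.jit
def stencil_kernel_v2(
    inputs,                     # pointer to the input tensor
    outputs,                    # pointer to the output tensor
    shape,                      # number of elements in input/output
    BLOCK_SIZE: tl.constexpr,
    NUM_OFFSETS: tl.constexpr,  # next power of 2 >= 2 * RADIUS + 1
    RADIUS: tl.constexpr,
):
    pid = tl.program_id(0)
    # Output positions in this block, and stencil taps -RADIUS, -RADIUS + 1, ...
    base_offs = tl.arange(0, BLOCK_SIZE)
    inc_offs = tl.arange(0, NUM_OFFSETS) - RADIUS
    output_offsets = pid * BLOCK_SIZE + base_offs
    # [BLOCK_SIZE, NUM_OFFSETS] grid: row i holds the input indices for output i.
    input_offsets = pid * BLOCK_SIZE + inc_offs[None, :] + base_offs[:, None]
    # Mask out-of-bounds inputs and the padding taps beyond +RADIUS;
    # masked lanes read as zero.
    mask = ((input_offsets >= 0) & (input_offsets < shape)
            & (inc_offs[None, :] <= RADIUS))
    input_tensor = tl.load(inputs + input_offsets, mask=mask, other=0.0)
    tl.store(outputs + output_offsets,
             tl.sum(input_tensor, axis=1),
             mask=output_offsets < shape)
```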
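The index arithmetic and masks of this kernel can be checked without a GPU. Below is a plain-Python emulation; `emulate_stencil_v2`, `direct_stencil` and this `next_power_of_2` are hypothetical helpers written for illustration (Triton itself provides `triton.next_power_of_2` for the host side). The emulation walks over program ids, rebuilds each entry of `input_offsets` as the kernel does, applies the same load and store masks, and compares the result with a direct zero-padded stencil sum.

```python
def next_power_of_2(n):
    """Smallest power of 2 >= n, for n >= 1 (same role as triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def emulate_stencil_v2(inputs, block_size, radius):
    """Plain-Python emulation of stencil_kernel_v2's offsets and masks (hypothetical)."""
    shape = len(inputs)
    num_offsets = next_power_of_2(2 * radius + 1)
    outputs = [0] * shape
    num_programs = (shape + block_size - 1) // block_size  # 1D launch grid
    inc_offs = [k - radius for k in range(num_offsets)]    # tl.arange(0, NUM_OFFSETS) - RADIUS
    for pid in range(num_programs):
        for base in range(block_size):                     # tl.arange(0, BLOCK_SIZE)
            out_idx = pid * block_size + base
            if out_idx >= shape:                           # store mask
                continue
            acc = 0
            for inc in inc_offs:
                in_idx = out_idx + inc                     # one entry of input_offsets
                # Load mask: inside [0, shape) and not a padding tap (inc <= RADIUS).
                # Masked lanes read other=0, so they are simply skipped here.
                if 0 <= in_idx < shape and inc <= radius:
                    acc += inputs[in_idx]
            outputs[out_idx] = acc
    return outputs


def direct_stencil(inputs, radius):
    """Reference: zero-padded sum over [i - radius, i + radius]."""
    n = len(inputs)
    return [sum(inputs[j] for j in range(i - radius, i + radius + 1) if 0 <= j < n)
            for i in range(n)]


data = list(range(10))
print(next_power_of_2(2 * 4 + 1))  # RADIUS=4: 9 taps padded to 16 lanes
print(emulate_stencil_v2(data, block_size=4, radius=1))
print(direct_stencil(data, radius=1))
```

Dropping the `inc <= radius` condition makes the emulation disagree with the reference whenever 2 * RADIUS + 1 is not already a power of 2, which is why that mask term is needed. Note also that for RADIUS = 4 almost half of the 16 lanes are padding.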
Even this kernel performs about the same as the previous Triton kernel and, like it, falls behind the hand-tuned CUDA implementation for large values of RADIUS.
@thumbe3 Ah, that makes sense. Loading the same element repeatedly probably still costs a separate load each time. I am curious how Triton handles convolution kernels, then, since they do something similar.
Problem statement: stencil sum computation on the GPU. The output at each index should be the sum of the input values at indices [index - RADIUS, index + RADIUS], where RADIUS is a constant known when the kernel is compiled. The input and output have the same size. For border elements of the output, input values at indices outside the valid range are treated as zero.
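Written as a formula, with N the input size, R = RADIUS, x the input and y the output (notation introduced here for clarity):

```math
y_i = \sum_{k=-R}^{R} \tilde{x}_{i+k}, \qquad
\tilde{x}_j =
\begin{cases}
x_j, & 0 \le j < N \\
0, & \text{otherwise}
\end{cases}
\qquad (0 \le i < N)
```

For example, with N = 8, R = 1 and x = (0, 1, ..., 7), this gives y = (1, 3, 6, 9, 12, 15, 18, 13).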
The solution shown below loads BLOCK_SIZE elements 2 * RADIUS + 1 times to obtain 2 * RADIUS + 1 tensors (one per offset), adds all of these tensors, and stores the result to the output in global memory. However, this is expensive in terms of load operations from global memory.
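To put rough numbers on this, here is a count for one interior block, assuming BLOCK_SIZE = 1024 (an illustrative value, not one from this thread). It counts element loads requested, not DRAM traffic: repeated addresses may be served from L1/L2 cache, so the real cost is lower than the ratio suggests, but the number of loads still grows linearly with RADIUS while the data actually needed barely changes.

```python
def load_counts(block_size, radius):
    """Element loads requested per interior block vs. distinct elements it needs."""
    taps = 2 * radius + 1
    requested = block_size * taps     # one BLOCK_SIZE-wide load per tap
    needed = block_size + 2 * radius  # the block plus a RADIUS-wide halo on each side
    return requested, needed


for radius in (1, 8, 64):
    requested, needed = load_counts(1024, radius)
    print(f"RADIUS={radius:>2}: requested={requested}, needed={needed}, "
          f"redundancy={requested / needed:.1f}x")
```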
I found that as RADIUS increases, the Triton kernel does not perform as well as a hand-tuned CUDA kernel similar to the one shown here: https://github.com/olcf/cuda-training-series/blob/master/exercises/hw2/stencil_1d_solution.cu.
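For comparison, the classic shared-memory approach to this stencil stages each block's BLOCK_SIZE inputs plus a RADIUS-wide halo on each side once, synchronizes, and then reads all taps from the staged tile. The plain-Python emulation below sketches that idea; the function name `tiled_stencil` and the zero-filled halo at the array edges are assumptions of this sketch rather than details taken from the linked file.

```python
def tiled_stencil(inputs, block_size, radius):
    """Hypothetical emulation of shared-memory tiling for the 1D stencil sum.

    Returns (outputs, number of global loads issued)."""
    n = len(inputs)
    outputs = [0] * n
    global_loads = 0
    for start in range(0, n, block_size):
        # Stage the tile once: the block plus a RADIUS-wide halo on each side.
        tile = []
        for g in range(start - radius, start + block_size + radius):
            if 0 <= g < n:
                tile.append(inputs[g])
                global_loads += 1
            else:
                tile.append(0)  # out-of-bounds inputs count as zero
        # On a GPU, a barrier such as __syncthreads() would go here.
        for local in range(block_size):
            i = start + local
            if i >= n:
                break
            # tile[local] holds global index i - radius, so the taps for
            # output i are tile[local], ..., tile[local + 2 * radius].
            outputs[i] = sum(tile[local:local + 2 * radius + 1])
    return outputs, global_loads


out, loads = tiled_stencil(list(range(8)), block_size=4, radius=1)
print(out, loads)
```

Each block loads 2 * RADIUS extra halo elements instead of 2 * RADIUS extra copies of the whole block, so the load count stays close to the input size as RADIUS grows.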
Is there a better way to do this in triton?
What might help is support for a shift operator on `tl.tensor`. You could then load the BLOCK_SIZE + 2 * RADIUS elements into a tensor once, call something like `tl.shift()` on it to get the 2 * RADIUS other shifted tensors, and add all of them together. Maybe it could be lowered similarly to how CUDA uses `__shfl_down_sync`?