Felix-Zhenghao opened 4 months ago
It seems `tl.make_block_ptr` can be simulated with some 2D pointer arithmetic and `tl.arange`.

Here is a simple example with `make_block_ptr`:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def block_kernel(x_ptr, o_ptr):
    absurd_shape: tl.constexpr = (2, -1)
    either_order: tl.constexpr = (1, 0)
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=absurd_shape, strides=(2, 3), offsets=(0, 0),
                                  block_shape=(2, 2), order=either_order)
    rng = tl.arange(0, 2)
    offsets_2d = rng[:, None] + rng[None, :] * 5  # [0, 1]^T + [0, 5] = [[0, 5], [1, 6]]
    tl.store(o_ptr + offsets_2d, tl.load(block_ptr))


x = torch.arange(10).cuda()
o = x * 0
block_kernel[(1,)](x, o)  # block_kernel copies [x[0], x[2], x[3], x[5]] into [o[0], o[1], o[5], o[6]]
print((x, o))
```
The output is:

```
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0'),
 tensor([0, 2, 0, 0, 0, 3, 5, 0, 0, 0], device='cuda:0'))
```
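Note that the `shape` and `order` parameters do not affect the output here, but they may help the compiler optimize better. `shape` is only consulted for bounds checking when the block pointer is loaded or stored with `boundary_check`, which this example does not do.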
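To see which addresses the block pointer in `block_kernel` touches, here is a CPU-only, pure-Python model of the addressing. It is a sketch under the assumption that, without `boundary_check`, element `(i, j)` of the block is read from `base + (offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1]`, so `shape` and `order` never enter the computation. The helper `block_offsets` is hypothetical, not part of Triton.

```python
def block_offsets(strides, offsets, block_shape):
    """Hypothetical helper: flat offsets (relative to base) covered by a 2D block.

    Models tl.make_block_ptr + tl.load without boundary_check; `shape` and
    `order` are deliberately absent because they do not change the addresses.
    """
    s0, s1 = strides
    o0, o1 = offsets
    rows, cols = block_shape
    return [[(o0 + i) * s0 + (o1 + j) * s1 for j in range(cols)] for i in range(rows)]


x = list(range(10))
o = [0] * 10

# Load side: the block pointer from block_kernel (strides (2, 3)).
load_offs = block_offsets(strides=(2, 3), offsets=(0, 0), block_shape=(2, 2))
# Store side: rng[:, None] + rng[None, :] * 5 is the same grid with strides (1, 5).
store_offs = block_offsets(strides=(1, 5), offsets=(0, 0), block_shape=(2, 2))

for i in range(2):
    for j in range(2):
        o[store_offs[i][j]] = x[load_offs[i][j]]

print(load_offs)   # [[0, 3], [2, 5]]
print(store_offs)  # [[0, 5], [1, 6]]
print(o)           # [0, 2, 0, 0, 0, 3, 5, 0, 0, 0]
```

The final list matches the `o` tensor printed by the GPU kernel, which supports the reading that only `base`, `strides`, `offsets` and `block_shape` decide what gets loaded here.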
The following pointer-arithmetic implementation gives the same output:
```python
@triton.jit
def block_equivalent_kernel(x_ptr, o_ptr):
    rng = tl.arange(0, 2)
    # Same addresses as the block pointer above: row stride 2, column stride 3.
    x = tl.load(x_ptr + rng[:, None] * 2 + rng[None, :] * 3)
    tl.store(o_ptr + rng[:, None] + rng[None, :] * 5, x)


x = torch.arange(10).cuda()
o = x * 0
block_equivalent_kernel[(1,)](x, o)
print((x, o))
```
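The pointer-arithmetic kernel has no bounds checking. Our understanding is that `shape` starts to matter once a block pointer is loaded with `boundary_check` (out-of-range elements are then filled according to `padding_option`, e.g. `"zero"`). With plain pointer arithmetic, the usual equivalent is a mask along the lines of `(row < M) & (col < N)` passed to `tl.load(ptrs, mask=mask, other=0)`. The CPU-only sketch below models that masked load in pure Python; `masked_block_load` and the 3x4 example matrix are hypothetical, chosen only for illustration.

```python
def masked_block_load(data, shape, strides, offsets, block_shape, other=0):
    """Hypothetical model of a bounds-checked 2D block load from flat memory.

    Element (i, j) of the block has logical index (offsets[0] + i, offsets[1] + j).
    If that index lies outside [0, shape[0]) x [0, shape[1]), memory is not read
    and `other` is returned instead, like tl.load(ptrs, mask=mask, other=other).
    """
    s0, s1 = strides
    o0, o1 = offsets
    rows, cols = block_shape
    block = []
    for i in range(rows):
        row = []
        for j in range(cols):
            r, c = o0 + i, o1 + j
            in_bounds = 0 <= r < shape[0] and 0 <= c < shape[1]
            row.append(data[r * s0 + c * s1] if in_bounds else other)
        block.append(row)
    return block


# A row-major 3x4 matrix stored flat: element (r, c) lives at r * 4 + c.
data = list(range(12))
# A 2x2 block starting at (2, 3) sticks out past the last row and the last column.
print(masked_block_load(data, shape=(3, 4), strides=(4, 1), offsets=(2, 3), block_shape=(2, 2)))
# [[11, 0], [0, 0]]
```

Only the in-bounds element `(2, 3)` is read from memory; the other three positions are padded, which is the behavior a bounds-checked block load is expected to reproduce.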
In `ops/flash_attention.py`, the K and V blocks are accessed through `make_block_ptr`, which raises a question for me. The input tensors `q`, `k`, `v` have shape (Batch, n_head, seq_num, dim_per_head), but when we get `K_block_ptr`, we use `shape=(BLOCK_DMODEL, Z_H_N_CTX)`, which is (dim_per_head, B*nh*T). So does `make_block_ptr` implicitly call something like `Tensor.view`? Moreover, what does the `order` parameter mean? I couldn't find an explanation in the official tutorials, the docs, or the original commit #1392.
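One way to think about the `K_block_ptr` question: a block pointer only describes memory through `base`, `strides` and `offsets`, so no `Tensor.view` should be needed. A contiguous (Batch, n_head, seq_num, dim_per_head) buffer can be addressed as a 2D (dim_per_head, B*nh*T) matrix, i.e. K transposed, simply by using strides (1, dim_per_head). The pure-Python check below assumes K is contiguous (so moving along `dim_per_head` advances memory by 1 and moving along the flattened batch/head/sequence axis advances it by `dim_per_head`); the sizes and helper names are hypothetical. As for `order`, our reading is that it lists dimensions from fastest- to slowest-varying in memory, so `(0, 1)` for this K-transposed view would say that dimension 0 (`dim_per_head`) is the contiguous one. Treat that as an interpretation rather than documented behavior.

```python
# Arbitrary sizes for the check: batch, heads, sequence length, head dim.
B, H, T, D = 2, 3, 4, 5


def contiguous_offset(b, h, t, d):
    """Flat offset of k[b, h, t, d] in a contiguous (B, H, T, D) buffer."""
    return ((b * H + h) * T + t) * D + d


def kt_block_offset(row, col, strides=(1, D)):
    """Flat offset of element (row, col) in a 2D (D, B*H*T) view with the given strides.

    `row` indexes dim_per_head; `col` indexes the flattened (batch, head, seq) axis.
    """
    return row * strides[0] + col * strides[1]


for b in range(B):
    for h in range(H):
        for t in range(T):
            col = (b * H + h) * T + t  # position along the flattened Z_H_N_CTX axis
            for d in range(D):
                assert kt_block_offset(d, col) == contiguous_offset(b, h, t, d)

# The (D, B*H*T) "matrix" covers every element of the buffer exactly once.
offsets = sorted(kt_block_offset(r, c) for r in range(D) for c in range(B * H * T))
assert offsets == list(range(B * H * T * D))
print("K^T with strides (1, D) addresses the same memory as k[b, h, t, d]")
```

Under these assumptions, selecting one (batch, head) pair would amount to starting the block at column `(b * H + h) * T` along the second dimension, which a kernel could pass through `offsets`.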
I suggest adding some best practices for using `make_block_ptr` to the 06-Fused Attention tutorial. Explaining what happens under the hood would be helpful. As the team said in the original commit:

Originally posted by Jokeren in https://github.com/openai/triton/issues/1392#issuecomment-1481324741