triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[Problem][Proposal] How does `make_block_ptr` work #3890

Open Felix-Zhenghao opened 4 months ago

Felix-Zhenghao commented 4 months ago

In ops/flash_attention.py, the K and V blocks are accessed through make_block_ptr, and I have a question about how that works:

The input tensors q, k, v have shape (Batch, n_head, seq_num, dim_per_head), but K_block_ptr is created with shape=(BLOCK_DMODEL, Z_H_N_CTX), i.e. (dim_per_head, B*nh*T). Does make_block_ptr implicitly call something like Tensor.view?

Moreover, what does the order parameter mean?

However, I couldn't find an explanation of this anywhere: not in the official tutorials, the documentation, or the original commit #1392.

I suggest adding some best practices for using make_block_ptr to the 06-Fused Attention tutorial. Explaining what happens under the hood would be helpful. As the team said in the original commit:

Can you add a tutorial for the tile pointer? I think that could be helpful.

Originally posted by Jokeren in https://github.com/openai/triton/issues/1392#issuecomment-1481324741

jeejeelee commented 4 months ago

+1

yunjiangster commented 4 months ago

It seems tl.make_block_ptr can be simulated with plain 2D pointer arithmetic and tl.arange.
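As a CPU-side mental model (a sketch, not Triton's actual implementation), a load through a block pointer without boundary checks reads element (i, j) of the block from the flat position (offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1] relative to base. The helper emulate_block_load below is hypothetical and only mirrors that address formula on a plain Python list:

```python
def emulate_block_load(memory, strides, offsets, block_shape):
    """Hypothetical reference model of tl.load(tl.make_block_ptr(...)) with no boundary_check.

    memory: flat list standing in for the buffer that starts at `base`.
    Returns a nested list with dimensions block_shape.
    """
    rows, cols = block_shape
    block = []
    for i in range(rows):
        row = []
        for j in range(cols):
            # Flat address of element (i, j) of the block; shape is never consulted.
            addr = (offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1]
            row.append(memory[addr])
        block.append(row)
    return block


x = list(range(10))
# Same parameters as the kernel below: strides=(2, 3), offsets=(0, 0), block_shape=(2, 2).
print(emulate_block_load(x, strides=(2, 3), offsets=(0, 0), block_shape=(2, 2)))
# [[0, 3], [2, 5]]
```

The block [[x[0], x[3]], [x[2], x[5]]] is exactly what the kernel below loads before scattering it into o.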

Here is a simple example with make_block_ptr:

import triton
import triton.language as tl
import torch

@triton.jit
def block_kernel(x_ptr, o_ptr):
    absurd_shape: tl.constexpr = (2, -1)
    either_order: tl.constexpr = (1, 0)
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=absurd_shape, strides=(2, 3), offsets=(0, 0), block_shape=(2, 2), order=either_order)
    rng = tl.arange(0, 2)
    offsets_2d = rng[:, None] + rng[None, :] * 5  # [0, 1]^T + [0, 5] = [[0, 5], [1, 6]]
    tl.store(o_ptr + offsets_2d, tl.load(block_ptr))

x = torch.arange(10).cuda()
o = x * 0
block_kernel[(1,)](x, o)  # block_kernel copies [x[0], x[2], x[3], x[5]] into [o[0], o[1], o[5], o[6]]
x, o

The output is

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0'),
 tensor([0, 2, 0, 0, 0, 3, 5, 0, 0, 0], device='cuda:0'))

Note that the shape and order parameters do not affect the output here, since no boundary_check is passed to tl.load; they may, however, help the compiler optimize.
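Where shape does start to matter is out-of-bounds handling: when given a block pointer, tl.load accepts boundary_check and padding_option instead of a mask. The hypothetical helper below sketches those semantics, assuming padding_option="zero": an element is read only if its logical index offsets[d] + i lies in [0, shape[d]) for every checked dimension d, and is zero otherwise. It is a mental model, not Triton's lowering.

```python
def emulate_checked_block_load(memory, shape, strides, offsets, block_shape,
                               boundary_check=(0, 1), padding=0):
    """Hypothetical model of a block-pointer load with boundary_check and zero padding."""
    rows, cols = block_shape
    block = []
    for i in range(rows):
        row = []
        for j in range(cols):
            idx = (offsets[0] + i, offsets[1] + j)  # logical 2D index into the parent tensor
            in_bounds = all(0 <= idx[d] < shape[d] for d in boundary_check)
            if in_bounds:
                row.append(memory[idx[0] * strides[0] + idx[1] * strides[1]])
            else:
                row.append(padding)  # out-of-bounds element is never read from memory
        block.append(row)
    return block


# A 3x3 row-major matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] stored flat.
matrix = list(range(1, 10))
# A 2x2 block at offsets (2, 2) only overlaps the bottom-right element.
print(emulate_checked_block_load(matrix, shape=(3, 3), strides=(3, 1),
                                 offsets=(2, 2), block_shape=(2, 2)))
# [[9, 0], [0, 0]]
```

With the absurd shape (2, -1) from the kernel above, any load with boundary_check=(1,) would therefore be fully masked in this model.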

The following implementation, using plain pointer arithmetic, gives the same output:

@triton.jit
def block_equivalent_kernel(x_ptr, o_ptr):
    rng = tl.arange(0, 2)
    x = tl.load(x_ptr + rng[:, None] * 2 + rng[None, :] * 3)
    tl.store(o_ptr + rng[:, None] + rng[None, :] * 5, x)

x = torch.arange(10).cuda()
o = x * 0
block_equivalent_kernel[(1,)](x, o)
x, o
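The same address formula suggests an answer to the K_block_ptr question in the original post: make_block_ptr does not reshape or copy anything, it only describes a 2D strided view of memory starting at base. Assuming k is contiguous, its (B, nh, T, D) buffer is also a row-major (B*nh*T, D) matrix, so declaring shape=(D, B*nh*T) with strides=(1, D) makes every loaded block a tile of K^T. The sizes and helper name below are hypothetical; the check runs on the CPU:

```python
def emulate_block_load(memory, strides, offsets, block_shape):
    """Hypothetical model of a block-pointer load (no boundary_check)."""
    rows, cols = block_shape
    return [[memory[(offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1]]
             for j in range(cols)]
            for i in range(rows)]


# Hypothetical sizes: B*nh*T = 6 flattened sequence positions, head dim D = 4.
ROWS, D = 6, 4
k_flat = list(range(ROWS * D))                             # contiguous buffer
k_rows = [k_flat[r * D:(r + 1) * D] for r in range(ROWS)]  # same data as a (ROWS, D) matrix

# Transposed view: shape=(D, ROWS), strides=(1, D). Load a (D, BLOCK_N) tile
# starting at flattened sequence position 2.
BLOCK_N = 3
tile = emulate_block_load(k_flat, strides=(1, D), offsets=(0, 2), block_shape=(D, BLOCK_N))

# The tile equals rows 2..4 of K, transposed; nothing was copied or reshaped beforehand.
expected = [[k_rows[2 + j][d] for j in range(BLOCK_N)] for d in range(D)]
print(tile == expected)  # True
```

In this transposed view dimension 0 is the one with stride 1; in the fused-attention tutorial the K block pointer is paired with order=(0, 1), which appears to describe exactly that, while leaving the addresses read unchanged.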