triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Do tensors and corresponding operators such as reshape, ravel, dot in triton support >=3 dimensions? #1129

Open rayleizhu opened 1 year ago

rayleizhu commented 1 year ago

I need to slice a 3D volume and do something like a batched matmul (torch.bmm), as in the snippet below:

a_h_offs = tl.program_id(0) * WIN_SIZE + tl.arange(0, WIN_SIZE)
a_w_offs = tl.program_id(1) * WIN_SIZE + tl.arange(0, WIN_SIZE)
a_c_offs = tl.program_id(2) * GROUP_SIZE + tl.arange(0, GROUP_SIZE)

# pointers to a 3D slice of a with shape (WIN_SIZE, WIN_SIZE, GROUP_SIZE)
a_slice_ptrs = a_ptr + a_h_offs[:, None, None] * a_h_stride + \
                       a_w_offs[None, :, None] * a_w_stride + \
                       a_c_offs[None, None, :] * a_c_stride
a_slice = tl.load(a_slice_ptrs)

# w is some 2D weight tensor with shape (GROUP_SIZE, DIM)
w = ...

# out is assumed to be (WIN_SIZE, WIN_SIZE, DIM)
out = tl.dot(a_slice, w)

Is this doable, or do I need to decompose it into 2D subproblems myself?

BTW, Triton is a great tool that may revolutionize operator customization in deep learning. It would be even better if there were more detailed documentation 😃

Jokeren commented 1 year ago

3D ops may be buggy in certain cases.

We will update the documentation soon.

rayleizhu commented 1 year ago

So, with Triton, it is safe to think in a way like CUDA-C (e.g., regard tl.dot as something like a tensor core MMA), right?

Jokeren commented 1 year ago

it is safe to think in a way like CUDA-C

Maybe I'm not getting your question clearly. The programming models are different: with CUDA you program each thread, block, and block cluster, but Triton only lets you specify the behavior of each block.
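
For example, even a plain vector add is written at the block level (a minimal sketch; the kernel name, pointer arguments, and BLOCK_SIZE are just example names):

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program instance (roughly one thread block) handles one tile.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    # You describe the behavior of the whole tile; Triton decides how to map
    # it onto threads, registers, and shared memory.
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)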

rayleizhu commented 1 year ago

You got it. My description abused some terminology. Actually, I turned to Triton exactly because I want to avoid tedious issues below the thread-block level. Thanks.

rayleizhu commented 1 year ago

Now that I'm thinking in a CUDA-C way, I have one more question:

Should I take shared memory into consideration? For example, do I need to choose the tile size in a thread block so it won't overflow smem? How about registers?

Jokeren commented 1 year ago

Should I take the shared memory into consideration

No, Triton handles it for you.

How about registers

Again, no.

That said, sometimes the generated code may have suboptimal shared memory usage or register counts. In that case, you can submit another issue.
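
The knobs you do control are the tile sizes and the launch meta-parameters (num_warps, num_stages), e.g. via the autotuner. A rough sketch with arbitrary example values and an example kernel, just to show where those knobs live:

import triton
import triton.language as tl

# Tile sizes, num_warps, and num_stages are the main levers over register
# pressure and shared-memory footprint; the compiler handles the actual
# allocation. The values below are arbitrary examples.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64,  'BLOCK_N': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N'],
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, M, N,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], x * 2.0, mask=mask)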

rayleizhu commented 1 year ago

Any suggestions for coping with 3D or 4D indexing in a safe way in the current version (Triton 1.1)?

Basically, I need to load a 3D or 4D tile (e.g. a tile of shape (ch, h_tile, w_tile)) and then reshape it to 2D for matmul.

Ideally, I want to do it like this:

offs_h = h_start + tl.arange(0, h_tile)[:, None, None]
offs_w = w_start + tl.arange(0, w_tile)[None, :, None]
offs_c = c_start + tl.arange(0, c_tile)[None, None, :]

tile_ptrs = ptr + offs_h * stride_h + offs_w * stride_w + offs_c * stride_c
tile_ptrs = tl.reshape(tile_ptrs, (h_tile*w_tile, c_tile))
tile = tl.load(tile_ptrs)

Otherwise, I reduce it to 2D manually:

offs_h = h_start + tl.arange(0, h_tile)[:, None]
offs_w = w_start + tl.arange(0, w_tile)[None, :]
offs_spatial = offs_h * stride_h + offs_w * stride_w
offs_spatial = tl.ravel(offs_spatial)  # (h_tile*w_tile,)

offs_c = (c_start + tl.arange(0, c_tile)) * stride_c  # (c_tile,)

tile_ptrs = ptr + offs_spatial[:, None] + offs_c[None, :]
tile = tl.load(tile_ptrs)
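
For context, continuing from tile above, the full pattern I have in mind ends with a plain 2D matmul against the weight (an untested sketch; w_ptr, stride_wc, stride_wd, and DIM are placeholders for my weight tensor's pointer, strides, and output width):

offs_k = c_start + tl.arange(0, c_tile)   # channel indices of this tile
offs_d = tl.arange(0, DIM)                # output feature indices
w_ptrs = w_ptr + offs_k[:, None] * stride_wc + offs_d[None, :] * stride_wd
w_tile = tl.load(w_ptrs)                  # (c_tile, DIM)

out = tl.dot(tile, w_tile)                # (h_tile*w_tile, DIM)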

Any recommendations?

Jokeren commented 1 year ago

If you don't have dot/trans ops in the code, I suppose you could still declare 3D/4D tensors. I haven't tested it, though.
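
Something along these lines, i.e. a 3D block with only elementwise ops (an untested sketch; the kernel and argument names are just examples, and boundary masks are omitted):

import triton
import triton.language as tl

@triton.jit
def scale3d_kernel(x_ptr, y_ptr,
                   stride_h, stride_w, stride_c,
                   BLOCK_H: tl.constexpr, BLOCK_W: tl.constexpr, BLOCK_C: tl.constexpr):
    # Load a (BLOCK_H, BLOCK_W, BLOCK_C) tile and apply an elementwise op;
    # no dot/trans involved.
    offs_h = tl.program_id(0) * BLOCK_H + tl.arange(0, BLOCK_H)
    offs_w = tl.program_id(1) * BLOCK_W + tl.arange(0, BLOCK_W)
    offs_c = tl.arange(0, BLOCK_C)
    offs = offs_h[:, None, None] * stride_h \
         + offs_w[None, :, None] * stride_w \
         + offs_c[None, None, :] * stride_c
    tile = tl.load(x_ptr + offs)        # 3D block
    tl.store(y_ptr + offs, tile * 2.0)  # elementwise op on the 3D block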