Open drisspg opened 1 year ago
This should be possible to codegen. The basic idea would be to use code like https://github.com/pytorch/pytorch/blob/c27c56a5c428837adfc8c3f5308b23c06dcda386/torch/_inductor/sizevars.py#L417-L424 to extract the size/stride/offset information from Inductor indexing expressions, and use that to construct the needed block_ptr.
Note that conversion would not work for some cases. Indirect indexing and ModularIndexing are both not representable as block_ptrs, so we would need to detect those cases and fall back to regular pointers.
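As a rough sketch of the kind of kernel such codegen could emit once the size/stride/offset terms are recovered (this is illustrative only, not the actual Inductor output; the kernel and parameter names here are made up):

```python
import triton
import triton.language as tl

# Hypothetical sketch: a strided 2D copy written with block pointers instead of
# raw pointer arithmetic. M, N, stride_m, stride_n stand in for the values the
# indexing-expression analysis would extract.
@triton.jit
def copy_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Raw-pointer style (roughly what is generated today):
    #   offs = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))[:, None] * stride_m \
    #        + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N))[None, :] * stride_n
    #   x = tl.load(in_ptr + offs, mask=..., other=0.0)
    # Block-pointer style: shape/strides/offsets are stated explicitly and
    # out-of-bounds elements are handled via boundary_check.
    in_block = tl.make_block_ptr(
        base=in_ptr,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    x = tl.load(in_block, boundary_check=(0, 1), padding_option="zero")
    out_block = tl.make_block_ptr(
        base=out_ptr,
        shape=(M, N),
        strides=(N, 1),  # assumes a contiguous output for simplicity
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    tl.store(out_block, x, boundary_check=(0, 1))
```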
@eellison: there are some open issues regarding slowdowns due to block pointers; it's not clear if they've been fixed. MTIA is going to want block pointers.
@ezyang if you mean this: https://github.com/openai/triton/issues/2301, it is closed - there was no actual bug. Could you indicate what other issues you may have encountered, if any?
If not, I can take a look at implementing this if that's ok @drisspg.
I'll start with matmul and we can work our way up to general fused kernels.
Btw, to my understanding, the TMA feature on the NVIDIA H100 GPU requires block ptrs to be used, and use of TMA has shown a 10-20% perf uplift on matmul. TMA is not implemented in the current template, but block_ptr would be a first step.
Here is a reference impl for Hopper, and a more readable version with less swiss-cheesing.
Updating the triton mm template should be even easier than general codegen.
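For the mm template specifically, the block-pointer pattern from the Triton tutorial (make_block_ptr plus tl.advance along K) would look roughly like the sketch below. This is not the actual Inductor template; strides and dtypes (fp16 inputs, fp32 accumulation) are assumed for illustration:

```python
import triton
import triton.language as tl

# Sketch of a block-pointer matmul inner loop, in the style of Triton's
# experimental block-pointer tutorial. Names are illustrative only.
@triton.jit
def mm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
              stride_am, stride_ak, stride_bk, stride_bn,
              stride_cm, stride_cn,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    a_block = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                offsets=(pid_m * BLOCK_M, 0),
                                block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                offsets=(0, pid_n * BLOCK_N),
                                block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_block, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(a, b)
        # Advance the block pointers along K instead of recomputing raw offsets.
        a_block = tl.advance(a_block, (0, BLOCK_K))
        b_block = tl.advance(b_block, (BLOCK_K, 0))
    c_block = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block, acc.to(tl.float16), boundary_check=(0, 1))  # assumes fp16 output
```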
I think one issue with block pointers is if we're on hardware that requires block pointers to be used. It's easy to imagine Inductor adding a pass that lifts into block pointers when possible, but what if we're in cases where we can't?
> It's easy to imagine Inductor adding a pass that lifts into block pointers when possible, but what if we're in cases where we can't?
In what scenario wouldn't you be able to use block ptrs?
Further, I think the idea is that we should eventually replace all raw ptrs with block ptrs?
> In what scenario wouldn't you be able to use block ptrs?
Block pointers only support strided access patterns, so scatter/gather and modular indexing (e.g. torch.roll) won't work.
AFAIK the block shape also has to be a power of two in each dimension, which will force us to use a different memory access pattern even for strided tensors (not necessarily worse, just different). Say, for example, you have a non-contiguous tensor with shape (10, 300) and an XBLOCK of 512: today each thread would load the same indices as if it were a contiguous tensor. With a block pointer, you would have to choose a block shape like (2, 256) or (1, 512), and whatever doesn't divide evenly into the block size will be masked out.
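To make the contrast concrete, here is a minimal sketch of the two access patterns for that (10, 300) case. The names (XBLOCK, BLOCK_M, BLOCK_N) are illustrative, not what Inductor actually generates:

```python
import triton
import triton.language as tl

# Illustrative device functions only, not launchable kernels as written.
@triton.jit
def flat_load(in_ptr, stride0, stride1, xnumel, XBLOCK: tl.constexpr):
    # Today: one flat index range, decomposed into row/col via div/mod.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel            # xnumel = 10 * 300 = 3000
    row = xindex // 300
    col = xindex % 300
    return tl.load(in_ptr + row * stride0 + col * stride1, mask=xmask, other=0.0)

@triton.jit
def block_ptr_load(in_ptr, stride0, stride1,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # With a block pointer: the block shape must be a power of two per dim,
    # e.g. (2, 256) or (1, 512); out-of-range elements are masked via
    # boundary_check rather than an explicit mask.
    ptr = tl.make_block_ptr(
        base=in_ptr,
        shape=(10, 300),
        strides=(stride0, stride1),
        offsets=(tl.program_id(0) * BLOCK_M, tl.program_id(1) * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    return tl.load(ptr, boundary_check=(0, 1), padding_option="zero")
```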
> I think one issue with block pointers is if we're on hardware that requires block pointers to be used.
Currently, TMA on Triton is opt-in - block_ptrs merely enable it, but non-TMA data movement can proceed without block_ptrs.
But yes, on hardware (e.g. MTIA) that requires block_ptrs, it seems the cases mentioned by @peterbell10 cannot be generated, so it would be preferable if hardware can always support standard ptrs as well.
But generally, would you say it would be a positive change to convert the majority of codegen to use block ptrs? @Chillee
Block pointers also don't support indirect loads (e.g. embeddings, jagged tensors, scatter/gather, etc).
They also don't support modular indexing (which can come from views).
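For illustration, an indirect load written as raw Triton code looks something like the sketch below. The addresses depend on runtime index data, so there is no affine shape/stride description a block_ptr could capture (the kernel and its parameters here are hypothetical):

```python
import triton
import triton.language as tl

# Illustrative gather of embedding rows. The row index is loaded from memory,
# so the final addresses are data-dependent - exactly what a block_ptr,
# which only encodes base/shape/strides/offsets, cannot express.
@triton.jit
def gather_rows(table_ptr, index_ptr, out_ptr, emb_dim, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < emb_dim
    # Data-dependent row index: the "indirect" part.
    row = tl.load(index_ptr + pid)
    vals = tl.load(table_ptr + row * emb_dim + offs, mask=mask, other=0.0)
    tl.store(out_ptr + pid * emb_dim + offs, vals, mask=mask)
```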
Summary
More a curiosity and a tracking issue for me or someone else to pick up. Triton has introduced block pointers: https://triton-lang.org/main/getting-started/tutorials/08-experimental-block-pointer.html
Personally I think these are easier to work with and reason about than raw strides. I could be missing something fundamental to Inductor's loop-level IR, but it would be interesting to see what would be needed to support their usage in, for example, this: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L31-L91
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @Xia-Weiwen @ngimel