pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Supporting Block_Ptrs in inductor code gen #109420

Open drisspg opened 1 year ago

drisspg commented 1 year ago

Summary

This is partly curiosity and partly a tracking issue for me or someone else to pick up. Triton has introduced block pointers: https://triton-lang.org/main/getting-started/tutorials/08-experimental-block-pointer.html

Personally I think these are easier to work with and reason about than raw strides. I could be missing something fundamental about Inductor's loop-level IR, but it would be interesting to see what would be needed to support them, for example in: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L31-L91
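To make the contrast with raw strides concrete, here is a minimal pure-Python emulation (not Triton code). It assumes a flat list as memory; `load_tile_raw` and `BlockPtr2D` are hypothetical names chosen for illustration. In Triton itself the corresponding calls are `tl.make_block_ptr`, `tl.advance`, and `tl.load(..., boundary_check=...)`, as shown in the linked tutorial.

```python
from dataclasses import dataclass, replace
from typing import Tuple


def load_tile_raw(data, base, stride_m, stride_n, off_m, off_n, bm, bn, M, N, pad=0.0):
    """Raw-pointer style: the caller computes every address and mask by hand."""
    tile = []
    for i in range(bm):
        row = []
        for j in range(bn):
            m, n = off_m + i, off_n + j
            in_bounds = 0 <= m < M and 0 <= n < N  # hand-written mask
            addr = base + m * stride_m + n * stride_n
            row.append(data[addr] if in_bounds else pad)
        tile.append(row)
    return tile


@dataclass(frozen=True)
class BlockPtr2D:
    """Toy stand-in for a block pointer (hypothetical, not Triton's API).

    Shape, strides, offsets and block shape travel together, so the
    load site no longer repeats the address and mask arithmetic.
    """
    base: int
    shape: Tuple[int, int]
    strides: Tuple[int, int]
    offsets: Tuple[int, int]
    block_shape: Tuple[int, int]

    def advance(self, delta):
        # Return a new descriptor moved by `delta` elements per dimension.
        new_offsets = (self.offsets[0] + delta[0], self.offsets[1] + delta[1])
        return replace(self, offsets=new_offsets)

    def load(self, data, pad=0.0):
        # Out-of-bounds elements are padded, mimicking a boundary check.
        return load_tile_raw(
            data, self.base, self.strides[0], self.strides[1],
            self.offsets[0], self.offsets[1],
            self.block_shape[0], self.block_shape[1],
            self.shape[0], self.shape[1], pad,
        )


# A 3x4 row-major tensor stored in a flat list.
data = list(range(12))
ptr = BlockPtr2D(base=0, shape=(3, 4), strides=(4, 1), offsets=(2, 2), block_shape=(2, 2))
print(ptr.load(data))  # second row falls outside the tensor and is padded
```

Both paths do the same arithmetic; the difference is only where the shape and stride bookkeeping lives.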

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @Xia-Weiwen @ngimel

jansel commented 1 year ago

This should be possible to codegen. The basic idea would be to use code like https://github.com/pytorch/pytorch/blob/c27c56a5c428837adfc8c3f5308b23c06dcda386/torch/_inductor/sizevars.py#L417-L424 to extract the size/stride/offset information from Inductor indexing expressions, and use that to construct the needed block_ptr.

Note that conversion would not work for some cases. Indirect indexing and ModularIndexing are both not representable as block_ptrs, so we would need to detect those cases and fall back to regular pointers.
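A rough sketch of that extract-or-fall-back idea follows. It assumes the index expression can be called as a plain Python function of the loop variables, and it probes numerically rather than working on symbolic expressions as Inductor does. `try_extract_strides` is a hypothetical helper, not an Inductor function.

```python
from itertools import product


def try_extract_strides(index_fn, sizes):
    """Return (offset, strides) if index_fn is affine over `sizes`, else None.

    None means "not expressible as a block pointer, use regular pointers".
    """
    ndim = len(sizes)
    offset = index_fn(*([0] * ndim))  # value at the origin
    strides = []
    for d in range(ndim):
        if sizes[d] > 1:
            unit = [1 if k == d else 0 for k in range(ndim)]
            strides.append(index_fn(*unit) - offset)  # step along dim d
        else:
            strides.append(0)
    # Exhaustive check; acceptable for toy sizes, not for real kernels.
    for point in product(*(range(s) for s in sizes)):
        expected = offset + sum(p * s for p, s in zip(point, strides))
        if index_fn(*point) != expected:
            return None
    return offset, tuple(strides)


# Transposed view with a storage offset: purely strided.
print(try_extract_strides(lambda i, j: 5 + i + 8 * j, (4, 8)))

# Modular indexing, as in a roll by 3 over 10 elements: falls back.
print(try_extract_strides(lambda i: (i + 3) % 10, (10,)))

# Indirect indexing through a lookup table: falls back.
rows = [4, 0, 7]
print(try_extract_strides(lambda i, j: rows[i] * 16 + j, (3, 16)))
```

The first call recovers an offset and per-dimension strides; the other two return `None`, which is the signal to keep emitting regular pointer arithmetic.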

ezyang commented 1 year ago

@eellison: there are some open issues regarding slowdowns due to block pointers, and it's not clear whether they've been fixed. MTIA is going to want block pointers.

jon-chuang commented 1 year ago

@ezyang if you mean this: https://github.com/openai/triton/issues/2301 it is closed - there was no actual bug. Could you indicate what other issues you may have encountered, if any?

If not, I can take a look into implementing this if that's ok @drisspg

I'll start with matmul and we can work our way up to general fused kernels.

jon-chuang commented 1 year ago

Btw, to my understanding, the TMA feature on NVIDIA H100 GPUs requires block ptrs to be used, and use of TMA has shown a 10-20% uplift in perf on matmul. TMA is not implemented in the current template, but block_ptr would be a first step.

Here is reference impl for Hopper. More readable version with less swiss-cheesing

jansel commented 1 year ago

Updating the triton mm template should be even easier than general codegen.

Chillee commented 1 year ago

I think one issue with block pointers is if we're on hardware that requires block pointers to be used. It's easy to imagine Inductor adding a pass that lifts into block pointers when possible, but what if we're in cases where we can't?

jon-chuang commented 1 year ago

> It's easy to imagine Inductor adding a pass that lifts into block pointers when possible, but what if we're in cases where we can't?

In what scenario wouldn't you be able to use block ptrs?

Further, I think the idea is that we should eventually replace all raw ptrs with block ptrs?

peterbell10 commented 1 year ago

> In what scenario wouldn't you be able to use block ptrs?

Block pointers only support strided access patterns so scatter/gather, modular indexing (e.g. torch.roll) won't work.

AFAIK the block shape also has to be a power of two in each dimension, which will force us to use a different memory access pattern even for strided tensors (not necessarily worse, just different). Say for example you have a non-contiguous tensor with shape (10, 300) and an XBLOCK of 512, today each thread would load the same indices as if it were a contiguous tensor. With a block pointer, you would have to choose a block shape like (2, 256) or (1, 512) and whatever doesn't divide evenly into the block size will be masked out.
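The arithmetic behind that example can be checked with a short sketch. It assumes, as stated above, that every block dimension must be a power of two and that partial blocks are masked; `lanes_launched` is a hypothetical helper name.

```python
from math import prod


def lanes_launched(shape, block_shape):
    """Total lanes across all blocks needed to cover `shape`."""
    for b in block_shape:
        assert b > 0 and b & (b - 1) == 0, "block dims must be powers of two"
    # Ceiling division per dimension gives the number of blocks.
    num_blocks = prod(-(-s // b) for s, b in zip(shape, block_shape))
    return num_blocks * prod(block_shape)


numel = 10 * 300  # 3000 useful elements

flat = lanes_launched((numel,), (512,))         # today: flat XBLOCK of 512
tile_a = lanes_launched((10, 300), (2, 256))    # block pointer, (2, 256)
tile_b = lanes_launched((10, 300), (1, 512))    # block pointer, (1, 512)

print(flat, flat - numel)      # 3072 72
print(tile_a, tile_a - numel)  # 5120 2120
print(tile_b, tile_b - numel)  # 5120 2120
```

Under these assumptions either 2-D block shape launches 5120 lanes for 3000 elements, versus 3072 lanes for the flat layout. Lane count alone does not determine performance, which matches the "not necessarily worse, just different" caveat.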

jon-chuang commented 1 year ago

> I think one issue with block pointers is if we're on hardware that requires block pointers to be used.

Currently, TMA on Triton is opt-in - block_ptrs merely enable it, but non-TMA data movement can proceed without block_ptrs.

But yes, in the scenario where there is hardware (e.g. MTIA) requiring block_ptrs, it seems those cases mentioned by @peterbell10 cannot be generated. So it seems that it would be preferred if hardware can always support the standard ptrs.

But generally, would you say it would be a positive change if we could convert the majority of codegen to use block ptrs? @Chillee

jansel commented 1 year ago

Block pointers also don't support indirect loads (e.g. embeddings, jagged tensors, scatter/gather, etc).

They also don't support modular indexing (which can come from views).
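As an illustration of both cases, the sketch below lists the addresses a kernel would touch and checks whether they follow a single stride. The concrete layouts are assumptions made up for the example (a (4, 6) slice of a buffer with row stride 8, iterated as one flat index, and an embedding table with 16 elements per row); `is_single_stride` is a hypothetical helper.

```python
def address_deltas(addresses):
    """Differences between consecutive addresses."""
    return [b - a for a, b in zip(addresses, addresses[1:])]


def is_single_stride(addresses):
    """True if one constant step describes the whole access sequence."""
    return len(set(address_deltas(addresses))) <= 1


# Plain strided access: offset 2, stride 8.
strided = [2 + 8 * i for i in range(6)]

# Flattened view of a (4, 6) slice whose rows are 8 elements apart:
# the flat index needs a floor-division and a modulo.
flat_view = [(i // 6) * 8 + (i % 6) for i in range(24)]

# Indirect load: row starts depend on runtime data (token ids).
token_ids = [3, 0, 3, 9]
embedding_rows = [t * 16 for t in token_ids]

print(is_single_stride(strided))         # True
print(is_single_stride(flat_view))       # False
print(is_single_stride(embedding_rows))  # False
```

In the flat view the step jumps at every row boundary. In the embedding case the steps are only known once the token ids are loaded, so no fixed shape/stride descriptor can describe them ahead of time.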