triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Optimizing Shared Memory Usage #4756

Open karan-dalal opened 2 months ago

karan-dalal commented 2 months ago

Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:

triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 167936, Hardware limit: 166912. Reducing block sizes or `num_stages` may help.

The operations involved in the kernel are complex, and many loads and intermediate variables are created during the derivation. I have a few questions about SRAM usage inside the kernel:

Note: I'm using a simple grid of shape [Batch, Heads] (like Flash Attention), so I don't think block sizes or `num_stages` are relevant here.
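For reference, the launch looks roughly like this. The stub kernel, names, and sizes below are made up purely for illustration (the real kernel has many more loads and intermediates); it just shows the [Batch, Heads] grid and where the block-size / `num_stages` knobs from the error message would enter:

```python
# Illustrative sketch only: a [Batch, Heads] grid like Flash Attention.
import triton
import triton.language as tl


@triton.jit
def _bwd_stub(X, Out, SEQ_LEN: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_b = tl.program_id(0)   # batch index
    pid_h = tl.program_id(1)   # head index
    base = (pid_b * tl.num_programs(1) + pid_h) * SEQ_LEN
    offs = tl.arange(0, BLOCK_N)
    x = tl.load(X + base + offs)
    tl.store(Out + base + offs, x)


def launch_bwd(x, out):
    batch, heads, seq_len = x.shape[0], x.shape[1], x.shape[2]
    grid = (batch, heads)              # one program per (batch, head)
    _bwd_stub[grid](x, out, SEQ_LEN=seq_len,
                    BLOCK_N=64,        # tile size along the sequence dimension
                    num_stages=2,      # pipelining depth; also scales shared memory
                    num_warps=4)
```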

I'm also happy to share the kernel code if needed. Hopefully there's some way I can rearrange operations, or evict values from SRAM, to reduce usage.

Li-dongyang commented 1 month ago

Any progress? I am quite interested in this Shared Memory Usage issue. Would it be possible for you to share your code?

binarman commented 6 days ago

@karan-dalal @Li-dongyang good day!

Triton is pretty smart in terms of shared memory usage, but it does not expose explicit handles to tweak that usage, and it treats load/store operations as "anchors": if you have loads and stores, they will stay there forever, unless the load results are unused. In general, moving load/store/dot operations around can help with shared memory pressure, but it is hard to give specific instructions.
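To illustrate what I mean by moving operations around, here is a toy sketch (not your kernel, and whether it actually changes peak usage depends on the real code and the scheduler): loading a tile right before the dot that consumes it, instead of grouping all loads at the top of the loop, shortens the interval during which several tiles are live at once and can give the compiler more freedom to reuse the shared-memory buffers behind them.

```python
# Toy example: where the V tile is loaded relative to the first dot is the
# only point of interest here.
import triton
import triton.language as tl


@triton.jit
def attn_sketch(Q, K, V, Out,
                stride_qm, stride_kn, stride_vn, stride_om,
                N_CTX: tl.constexpr,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_D: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q = tl.load(Q + offs_m[:, None] * stride_qm + offs_d[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)

    for start in range(0, N_CTX, BLOCK_N):
        k = tl.load(K + (start + offs_n)[:, None] * stride_kn + offs_d[None, :])
        s = tl.dot(q, tl.trans(k))
        # Loading V here, right before the dot that consumes it (rather than
        # next to the K load above), keeps fewer tiles live at the same time
        # and may lower peak shared-memory usage.
        v = tl.load(V + (start + offs_n)[:, None] * stride_vn + offs_d[None, :])
        acc += tl.dot(s.to(v.dtype), v)

    tl.store(Out + offs_m[:, None] * stride_om + offs_d[None, :], acc)
```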

Another way to address this problem is to optimize shared memory usage in the compiler. The AMD backend has an optimization that could potentially help. Currently it is AMD-specific, but I think we can move it to common code if it proves useful. Could anyone with this problem share their kernel code + parameters, or the Triton GPU IR (it is saved in a ~/.triton/cache/*/*.ttgir file)?
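If it helps, something like this should list the cached TTGIR files (assuming the default cache location, i.e. you have not pointed the cache elsewhere via the TRITON_CACHE_DIR environment variable):

```python
# List the Triton GPU IR files in the default kernel cache.
import glob
import os

cache = os.path.expanduser("~/.triton/cache")
for path in sorted(glob.glob(os.path.join(cache, "*", "*.ttgir"))):
    print(path)
```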