triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Optimizing Shared Memory Usage #4756

Open karan-dalal opened 1 week ago

karan-dalal commented 1 week ago

Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:

triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 167936, Hardware limit: 166912. Reducing block sizes or `num_stages` may help.
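The two numbers in the error differ by exactly 1 KiB: 164 KiB requested against 163 KiB available. 163 KiB matches the maximum shared memory per thread block on compute capability 8.0 GPUs such as the A100, although the thread does not say which GPU is in use. The sketch below does that arithmetic. It also includes a hypothetical `tile_bytes` helper built on a deliberately simplified model, to show how tile shape, element size, and multi-buffering of pipelined loads (roughly what `num_stages` controls) scale the footprint. Triton's real allocator also accounts for padding, swizzling, and scratch space, so treat these numbers as rough.

```python
# Numbers copied from the OutOfResources message above.
REQUIRED = 167936        # bytes of shared memory the compiled kernel asked for
HARDWARE_LIMIT = 166912  # bytes available per program (thread block) on this GPU

overflow = REQUIRED - HARDWARE_LIMIT
print(f"required = {REQUIRED / 1024:.0f} KiB, limit = {HARDWARE_LIMIT / 1024:.0f} KiB")
print(f"over by {overflow} bytes ({overflow / 1024:.0f} KiB)")


def tile_bytes(rows: int, cols: int, elem_bytes: int, buffers: int = 1) -> int:
    """Rough size of one 2-D tile staged in shared memory.

    Simplified, hypothetical model: ignores padding, swizzling, alignment, and
    scratch used for reductions or layout conversions. `buffers` stands in for
    multi-buffering of pipelined loads (related to, not identical to, num_stages).
    """
    return rows * cols * elem_bytes * buffers


# How small a 1 KiB overshoot is relative to typical tiles:
print(tile_bytes(16, 32, 2))             # 16x32 fp16 tile
print(tile_bytes(64, 64, 2))             # 64x64 fp16 tile
print(tile_bytes(64, 64, 2, buffers=3))  # same tile, triple-buffered
```

Under this model, shrinking a single pipelined tile or dropping one buffer level would more than cover the 1 KiB gap.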

The operations involved in the kernel are complex, and the derivation requires many loads and intermediate variables. I have a few questions about SRAM (shared memory) usage inside the kernel:

Note: I'm using a simple grid of shape [Batch, Heads] (as in Flash Attention), so I don't think block sizes or `num_stages` are relevant.
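With a [Batch, Heads] grid, the launch creates one program per (batch, head) pair, so the grid itself does not depend on any block-size parameter. Shared memory, however, is allocated per program, and its size follows from the shapes of the tiles each program loads and computes on. The plain-Python sketch below needs neither a GPU nor Triton. It enumerates what each program would index. The `flash_style_programs` helper, the contiguous (B, H, N, D) layout, and the `BLOCK_N` chunk length are all hypothetical.

```python
from itertools import product


def flash_style_programs(B: int, H: int, N: int, D: int, BLOCK_N: int):
    """Enumerate work for a [B, H] launch grid over a contiguous (B, H, N, D) tensor.

    Mirrors, in plain Python, what program_id(0) / program_id(1) would select in a
    kernel launched with grid (B, H). BLOCK_N is a hypothetical chunk length used
    to walk the sequence inside one program.
    """
    stride_b, stride_h, stride_n = H * N * D, N * D, D
    for pid_b, pid_h in product(range(B), range(H)):
        base = pid_b * stride_b + pid_h * stride_h  # start of this (batch, head) slice
        # Each loop step touches a (BLOCK_N x D) tile starting at this offset; in this
        # simplified picture, that tile shape (not the grid) sizes per-program buffers.
        chunks = [
            (base + start * stride_n, min(BLOCK_N, N - start))
            for start in range(0, N, BLOCK_N)
        ]
        yield (pid_b, pid_h), chunks


programs = dict(flash_style_programs(B=2, H=3, N=10, D=4, BLOCK_N=4))
print(len(programs))     # one program per (batch, head) pair
print(programs[(1, 2)])  # (element offset, rows) for each chunk of the last program
```

Changing `B` or `H` changes how many programs run. Changing `BLOCK_N` (or the head dimension `D`) changes the tile each program holds, and that is what moves shared-memory usage.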

I'm also happy to share the kernel code if needed. Hopefully there's some way I can rearrange operations and evict data from SRAM to reduce usage.
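Whether rearranging operations helps depends on how many buffers are live at the same time. To my understanding, Triton does not provide a language-level call for freeing shared memory. Its compiler assigns shared-memory buffers based on liveness in its own IR, and `tl.load`'s `eviction_policy` argument is a cache hint, not a way to free SRAM. The toy model below uses a hypothetical `peak_bytes` helper that treats every named tile as shared-memory resident, which real kernels do not do. It only illustrates why loading operands just before they are consumed can lower the peak.

```python
def peak_bytes(schedule, sizes):
    """Peak live bytes for a straight-line schedule in a toy liveness model.

    schedule: list of (output_name, input_names) steps, executed in order.
    sizes: bytes each named buffer would occupy.
    A buffer is live from the step that produces it through its last use.
    Hypothetical model of liveness-based reuse, not Triton's actual allocator.
    """
    last_use = {}
    for step, (out, ins) in enumerate(schedule):
        last_use.setdefault(out, step)  # unused outputs die at their own step
        for name in ins:
            last_use[name] = step
    live, peak = set(), 0
    for step, (out, ins) in enumerate(schedule):
        live.add(out)
        peak = max(peak, sum(sizes[n] for n in live))
        live = {n for n in live if last_use[n] > step}  # drop dead buffers
    return peak


sizes = {name: 8192 for name in ["a", "b", "c", "d", "ab", "cd", "out"]}

# Load all four input tiles up front, then compute.
eager = [("a", []), ("b", []), ("c", []), ("d", []),
         ("ab", ["a", "b"]), ("cd", ["c", "d"]), ("out", ["ab", "cd"])]

# Same computation, but each pair is loaded right before it is consumed.
lazy = [("a", []), ("b", []), ("ab", ["a", "b"]),
        ("c", []), ("d", []), ("cd", ["c", "d"]), ("out", ["ab", "cd"])]

print(peak_bytes(eager, sizes), peak_bytes(lazy, sizes))
```

In this toy model, deferring the loads of `c` and `d` until `a` and `b` have been consumed lowers the peak from 40 KiB to 32 KiB without changing the result.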

Li-dongyang commented 6 days ago

Any progress? I am quite interested in this Shared Memory Usage issue. Would it be possible for you to share your code?