Open karan-dalal opened 2 months ago
Any progress? I am quite interested in this Shared Memory Usage
issue. Would it be possible for you to share your code?
@karan-dalal @Li-dongyang good day!
Triton is pretty smart in terms of shared memory usage, but it does not expose explicit handles to tweak it, and it considers load/store operations as "anchors": if you have loads and stores, they will stay there forever, unless the load results are unused. In general, moving load/store/dot operations around could help with the shared memory problem, but I struggle to give particular instructions.
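To make the "moving operations around" idea concrete, here is a toy sketch (the kernel, pointer names, shapes and dtypes are invented for illustration, not taken from your code). The idea is to load each tile just before the dot that consumes it instead of loading everything at the top, so the compiler has a chance to reuse shared memory buffers; whether it actually helps depends on the kernel.

```python
import triton
import triton.language as tl

# Hypothetical example: a_ptr/b_ptr/c_ptr point to fp16 BLOCK x BLOCK tiles,
# out_ptr to an fp32 tile; BLOCK is a power of two >= 16.
@triton.jit
def toy_kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]

    # Instead of loading a, b and c all at the top of the kernel (which keeps
    # three tiles live at once), load each tile just before the dot that uses it.
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    acc = tl.dot(a, b)

    # c is only materialized after a/b are no longer needed, which gives the
    # compiler a chance to reuse their shared-memory buffers.
    c = tl.load(c_ptr + idx)
    acc = tl.dot(acc.to(tl.float16), c)

    tl.store(out_ptr + idx, acc)
```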
Another way to address this problem is to try to optimize shared memory usage in the compiler. The AMD backend has an optimization which could potentially help.
Currently it is AMD-specific, but I think we can move it to common code if it is useful.
Could anyone with this problem share the kernel code + parameters, or the Triton GPU IR (it is saved in a ~/.triton/cache/*/*.ttgir file)?
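For example, after running the kernel once, something like this will list the cached IR files you could attach (only the cache path pattern above comes from Triton; the rest is just an illustration):

```python
# Minimal sketch: list the Triton GPU IR files from the local compilation cache.
import glob
import os

pattern = os.path.expanduser("~/.triton/cache/*/*.ttgir")
for path in sorted(glob.glob(pattern)):
    print(path)  # attach the one matching your kernel to the issue
```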
Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:
The operations involved in the kernel are complex, and I have many loads and intermediate variables created during the derivation. I have a few questions about the SRAM usage inside the kernel:
1. Does the placement of `tl.load` matter, or is Triton smart enough to compile it into the most memory-optimal form? I.e., can I `tl.load` all required variables at the beginning and expect the same memory usage as if I were `tl.load`-ing them right before the operations they are involved in?
2. If I `tl.store` and then `tl.load` in the same kernel, will this force Triton to write the data out to HBM and then reload it from HBM?
3. If I load a variable with `x1 = tl.load(ptr)` and then later load another variable into it with `x1 = tl.load(ptr2)`, will this overwrite the memory in SRAM?

Note: I'm using a simple grid of shape [Batch, Heads] (like Flash Attention). I don't think blocks or `num_stages` are relevant.
I'm also happy to share the kernel code, if needed. Hopefully there's some way I can rearrange operations and evict data from SRAM to optimize usage.
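For concreteness, here is a toy sketch of the three patterns I'm asking about (not my real kernel; the pointer names, shapes, and dtypes are made up):

```python
import triton
import triton.language as tl

# Hypothetical example: x_ptr/y_ptr are fp16 BLOCK x BLOCK tiles, tmp_ptr and
# out_ptr are fp32 scratch/output tiles; BLOCK is a power of two >= 16.
@triton.jit
def pattern_sketch(x_ptr, y_ptr, tmp_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]

    # (1) Placement: load everything up front ...
    x = tl.load(x_ptr + idx)
    y = tl.load(y_ptr + idx)
    # ... versus loading y only right before the dot that uses it.
    acc = tl.dot(x, y)

    # (2) A store followed by a load of the same buffer inside one kernel:
    # does this round-trip through HBM?
    tl.store(tmp_ptr + idx, acc)
    t = tl.load(tmp_ptr + idx)

    # (3) Reusing a variable name for a second load: is the old tile's SRAM
    # reused, or do both loads stay live?
    x = tl.load(y_ptr + idx)

    tl.store(out_ptr + idx, t + x.to(tl.float32))
```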