Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 167936, Hardware limit: 166912. Reducing block sizes or `num_stages` may help.
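As a quick sanity check on those numbers, here is a small sketch that uses only the two byte counts from the error message (the variable names are just for illustration):

```python
# Byte counts copied from the OutOfResources message above.
required_bytes = 167936
limit_bytes = 166912

# How far over the shared-memory limit the kernel is.
overshoot_bytes = required_bytes - limit_bytes

# The same figures expressed in KiB (1 KiB = 1024 bytes).
required_kib = required_bytes / 1024
limit_kib = limit_bytes / 1024

print(f"required: {required_kib:.0f} KiB, limit: {limit_kib:.0f} KiB, "
      f"over by {overshoot_bytes} bytes")
```

So the kernel is over the limit by 1024 bytes (164 KiB requested against 163 KiB available), which is why I'm hoping a small rearrangement might be enough.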
The operations involved in the kernel are complex, and I have many loads and intermediate variables created during the derivation. I have a few questions about SRAM usage inside the kernel:
Does the order of tl.load calls matter, or is Triton smart enough to compile the kernel into the most memory-optimal form? I.e., can I tl.load all required variables at the beginning and expect the same memory usage as if I were to tl.load them right before the operations they are involved in?
Is there a way to forcibly evict a variable from shared memory after loading it, if I no longer need to use it?
If I tl.store a value and later tl.load it in the same kernel, will this force Triton to write it out to HBM and then reload it from HBM?
If I load x1 = tl.load(ptr) and then later load another variable into it with x1 = tl.load(ptr2), will this overwrite the memory in SRAM?
Is there a way to understand memory usage breakdown in a compiled kernel?
Note: I'm using a simple grid of shape [Batch, Heads] (like Flash Attention). I don't think block sizes or num_stages are relevant here.
I'm also happy to share the kernel code, if needed. Hopefully there's some way I can re-arrange operations and evict from SRAM to optimize usage.