The memory bottleneck is largely caused by the attention matrix and the softmax.
In natten2d.py:
attn = natten2dqkrpb(q, k, self.rpb, self.kernel_size, self.dilation)  # materializes attention logits (QK^T + relative positional bias) in global memory
attn = attn.softmax(dim=-1)  # softmax over each neighborhood
attn = self.attn_drop(attn)  # dropout on the attention weights
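For reference, a rough back-of-the-envelope estimate of how large that intermediate tensor can get (the sizes below are illustrative assumptions, not taken from my actual run):

```python
# Illustrative sizes (assumptions, not from the actual benchmark).
B, heads, H, W, kernel_size = 4, 8, 128, 128, 7
bytes_per_elem = 2  # fp16

# The BMM-style path materializes attention weights of shape
# (B, heads, H, W, kernel_size * kernel_size) in global memory.
num_elements = B * heads * H * W * kernel_size * kernel_size
print(f"attention weights: {num_elements * bytes_per_elem / 2**20:.1f} MiB")
```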
You mentioned that a newer version of NATTEN is on your near-term roadmap.
I recommend considering FlashAttention, which is also supported by PyTorch's standard scaled-dot-product attention operator.
As far as I know, FlashAttention never materializes the full attention matrix explicitly, which saves memory.
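As a point of comparison, here is a minimal sketch of the fused path through PyTorch's built-in operator; note this is plain (global) self-attention rather than neighborhood attention, and the shapes are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration only.
B, heads, seq, head_dim = 2, 8, 4096, 64
q = torch.randn(B, heads, seq, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# scaled_dot_product_attention can dispatch to a FlashAttention kernel,
# which never materializes the (seq x seq) attention matrix in global memory.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
```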
Thank you so much for looking into this!
Yes, the memory overhead is expected, not just with Neighborhood Attention but with any BMM-style implementation of attention, because they store the attention weights in global memory.
Fused kernels (like FlashAttention) are our primary goal, and we will likely stop actively maintaining our standard BMM-style kernels (naive, tiled, and GEMM), though those will likely be kept in NATTEN for older architectures.
Fused neighborhood attention is now available (#111).
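For anyone landing here later, a minimal sketch of what calling the fused op might look like; the function name, argument names, and heads-last layout below reflect my understanding of the newer NATTEN API and may not match the released version exactly:

```python
import torch
from natten.functional import na2d  # fused neighborhood attention (name assumed from recent releases)

# Toy shapes for illustration only; heads-last (B, H, W, heads, head_dim) layout is an assumption.
B, H, W, heads, head_dim = 2, 56, 56, 4, 32
q = torch.randn(B, H, W, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The fused kernel computes neighborhood attention without storing the
# per-query attention weights in global memory, unlike the BMM-style path above.
out = na2d(q, k, v, kernel_size=7, dilation=1)
```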
I am testing the memory usage of NATTEN. My code is here.
Here is one memory snapshot, visualized using torch.cuda.memory.
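In case it helps others reproduce a similar measurement, here is a minimal sketch of how such a snapshot can be captured with PyTorch's memory tooling (these are private, version-dependent APIs, and the model and input here are placeholders rather than my actual test code):

```python
import torch

# Start recording allocation history (available in recent PyTorch versions).
torch.cuda.memory._record_memory_history(max_entries=100_000)

model = torch.nn.Linear(1024, 1024).cuda().half()  # placeholder for the NATTEN model under test
x = torch.randn(8, 4096, 1024, device="cuda", dtype=torch.float16)
out = model(x)

print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
# Dump a snapshot that can be opened at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("natten_memory_snapshot.pickle")
```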