SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Performance compared to local-attention + masking #120

Open Chillee opened 5 months ago

Chillee commented 5 months ago

I saw https://github.com/SHI-Labs/NATTEN/issues/89

> As far as I know both FAv2 and xFormers' FMHA support 1-D sliding window attention with causal masking, so you probably can use them for now, but again only when your token space is 1-D, and only when you're doing causal masking.

Can you elaborate on the performance differences between Natten and local/sliding-window attention? I understand that Natten is not a special case of local/sliding-window attention, but as I understand it, you should be able to implement Natten using (1) local/sliding-window attention and (2) a special attention mask.

Is that correct? If so, I think you would get most of Natten's performance benefits (i.e., you would skip much of the unnecessary computation).
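A minimal pure-Python sketch of the idea in this question, assuming the usual 1-D neighborhood attention definition for an odd kernel size `k` (each query attends to exactly `k` keys; near the ends the window is shifted rather than truncated). The helpers `na_window_1d` and `sliding_window_1d` are hypothetical illustrations, not NATTEN's API. The check shows that every NA window fits inside a sliding window of radius `k - 1`, which is why sliding-window attention plus an extra mask can express 1-D NA.

```python
def na_window_1d(i: int, n: int, k: int) -> range:
    """Keys attended by query i under 1-D neighborhood attention (hypothetical helper).

    Assumes odd k <= n. The window is centered on i where possible and shifted
    (not truncated) near the boundaries, so it always holds exactly k keys.
    """
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def sliding_window_1d(i: int, n: int, radius: int) -> range:
    """Keys within `radius` of query i; truncated at the boundaries."""
    return range(max(i - radius, 0), min(i + radius + 1, n))


n, k = 10, 5
for i in range(n):
    # NA keys never lie farther than k - 1 from the query...
    assert set(na_window_1d(i, n, k)) <= set(sliding_window_1d(i, n, k - 1))

# ...but a sliding window of the "natural" radius k // 2 shrinks at the edges,
# whereas NA keeps k keys by shifting its window.
print(list(na_window_1d(0, n, k)))            # [0, 1, 2, 3, 4]
print(list(sliding_window_1d(0, n, k // 2)))  # [0, 1, 2]
```

So the extra mask is what turns the wider sliding window into NA's fixed-size, shifted window.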

alihassanijr commented 5 months ago

Apologies in advance if I've misunderstood your question; I think you're suggesting that we implement neighborhood attention by modifying the attention mask logic in FMHA/FAv2 and using the existing sliding window parameter?

If so, the short answer is that we kind of are, but that only covers a very small portion of our use cases.

Long answer:

The main difference between the sliding window attention implementations in FAv2/FMHA and NATTEN is that we're more interested in multi-dimensional inputs (i.e., 2-D and 3-D feature maps representing images and videos).

Put simply, neighborhood attention is to self attention what convolution is to fully-connected layers (think nn.Linear over the spatial axes as well as the channels). Similarly, it supports parameters like dilation (but padding and stride don't really make sense).
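To make the convolution analogy concrete, here is a small sketch assuming that a 2-D neighborhood is the Cartesian product of two clamped 1-D windows (the helper names are hypothetical, not NATTEN's API). It shows why padding has no role: instead of padding, the window shifts at the borders, so every query still sees `k * k` real keys.

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def na_window_2d(h, w, H, W, k):
    """2-D neighborhood of query (h, w): the product of the per-axis 1-D windows."""
    return [(y, x) for y in na_window_1d(h, H, k) for x in na_window_1d(w, W, k)]


H, W, k = 6, 8, 3
# Like a k x k convolution, each query sees a k x k patch; unlike a convolution,
# the patch is shifted at the borders instead of padded.
assert all(len(na_window_2d(h, w, H, W, k)) == k * k for h in range(H) for w in range(W))
print(na_window_2d(0, 0, H, W, k)[:3])  # [(0, 0), (0, 1), (0, 2)]
```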

From an implementation point of view, NATTEN treats the attention problem as two GETTs (General Tensor-Tensor contractions) instead of two GEMMs (General Matrix-Matrix multiplications). This basically means we try not to assume that the "row mode" is one-dimensional (a sequence of tokens).
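A NumPy sketch of the GEMM/GETT distinction for the QK^T step (an illustration of the math only, not of how the kernels are written; sizes are arbitrary). Both views produce the same scores; the GETT view simply keeps `(h, w)` as a multi-dimensional row mode instead of flattening it.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, D = 4, 5, 8
q = rng.standard_normal((H, W, D))
k = rng.standard_normal((H, W, D))

# GETT view: the row mode of Q and of K is the 2-D index (h, w); the contraction
# runs over the channel mode d, giving a rank-4 score tensor [Hq, Wq, Hk, Wk].
scores_gett = np.einsum("hwd,xyd->hwxy", q, k)

# GEMM view: flatten (h, w) into a single token axis and do an ordinary matrix product.
scores_gemm = q.reshape(H * W, D) @ k.reshape(H * W, D).T

# Same numbers; the GETT form just keeps the spatial structure explicit.
assert np.allclose(scores_gett.reshape(H * W, H * W), scores_gemm)
```

The difference matters for a neighborhood kernel because deciding which keys a query tile needs is a question about `(h, w)`, not about a flat token index.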

It is true that implementing 1-D neighborhood attention in FAv2/FMHA is relatively trivial. (NA also supports dilation, but there's an easy hack for that too.)
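One way to get dilation from a non-dilated kernel, offered here as an assumption rather than as the exact hack referred to above, is to split the sequence into `dilation` interleaved subsequences and run ordinary NA inside each one. The sketch below (hypothetical helper names, odd `k`, every subsequence at least `k` long) maps a query to its dilated neighbors that way. When the length is divisible by the dilation, the split is just a reshape and transpose, so under that assumption the kernel itself would not need to know about dilation.

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def dilated_na_window_1d(i, n, k, dilation):
    """Dilated 1-D NA via partitioning (hypothetical helper).

    Token i belongs to subsequence i % dilation, at position i // dilation within it.
    Non-dilated NA is applied inside that subsequence and the resulting positions
    are mapped back to the original sequence. Assumes each subsequence has >= k tokens.
    """
    group = i % dilation
    group_len = len(range(group, n, dilation))
    return [group + p * dilation for p in na_window_1d(i // dilation, group_len, k)]


print(dilated_na_window_1d(10, 20, 3, 2))  # [8, 10, 12]
print(dilated_na_window_1d(1, 20, 3, 2))   # [1, 3, 5]
```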

However, 2-D and 3-D require major modifications to convert the GEMMs in FAv2/FMHA into GETTs, and that's essentially what we're after right now.

Our fused kernels are based on FMHA, and our 1-D kernels have latency similar to FMHA's.

That does not hold for 2-D and 3-D, because those are GETT problems, which somewhat complicates the data iteration. Speaking mostly for the forward pass kernel, that is where most of the additional latency comes from. I'll go out on a limb and say this will probably become a non-issue once we have kernels that target Hopper's TMA (Tensor Memory Accelerator), because the Hopper architecture is specifically designed to handle GETTs and even treats GEMMs as trivial GETTs (although I'm not too sure of this last bit; it's loosely based on Vijay Thakkar's talks).

So we could probably accelerate 1-D neighborhood attention by making those small changes to FAv2, but it would fall short of supporting 2-D and 3-D problems.

What we really hope to enable with NATTEN in the future is not just multi-dimensional sliding windows, but also explicit spatio-temporal attention (a single attention call with causal masking along some dimensions and no causal masking along others).
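As a rough illustration of "causal along some dimensions, not others", here is a hypothetical mask predicate for a `(time, height, width)` token grid: causal along time and a regular neighborhood along space. The causal rule used here (the `kt` most recent frames, including the current one) and all names are assumptions for the sketch, not a NATTEN interface.

```python
def na_window_1d(i, n, k):
    """Clamped, non-causal 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def causal_window_1d(i, k):
    """Up to k most recent positions, ending at i (hypothetical causal rule)."""
    return range(max(i - k + 1, 0), i + 1)


def spatio_temporal_allowed(q, kv, H, W, kt, ks):
    """May query q = (t, h, w) attend to key kv = (t', h', w')?

    Causal along time (dimension 0), non-causal neighborhood along space.
    """
    (t, h, w), (tk, hk, wk) = q, kv
    return (tk in causal_window_1d(t, kt)
            and hk in na_window_1d(h, H, ks)
            and wk in na_window_1d(w, W, ks))


T, H, W = 4, 5, 5
keys = [(t, h, w) for t in range(T) for h in range(H) for w in range(W)]
print(sum(spatio_temporal_allowed((3, 0, 0), kv, H, W, kt=2, ks=3) for kv in keys))  # 18
```

The 18 keys are 2 frames times a 3 x 3 spatial neighborhood; no key from a later frame is ever allowed.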

Chillee commented 5 months ago

I'm still somewhat confused. Is it not possible to implement 2D Natten with a 1D FMHA + attention mask?

[Image: diagram of the kernel tiling for 1-D, 2-D, and 3-D neighborhood attention]

To clarify, this diagram is showing a 2d Natten right? I think you should be able to implement it with something along the lines of

(pseudocode)

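    import torch
    import torch.nn.functional as F

    Height, Width, Dim, kernel_size = 8, 8, 32, 3  # illustrative sizes
    q, k, v = (torch.randn(Height, Width, Dim) for _ in range(3))

    def na_mask_1d(n, ks):  # [n_q, n_k]: True where key j is in query i's clamped window
        i = torch.arange(n)
        start = (i - ks // 2).clamp(0, n - ks)[:, None]
        return (i[None, :] >= start) & (i[None, :] < start + ks)

    # mask: [Height_q, Width_q, Height_k, Width_k]; a 2-D neighborhood is the product of 1-D ones
    mask = na_mask_1d(Height, kernel_size)[:, None, :, None] & na_mask_1d(Width, kernel_size)[None, :, None, :]
    mask_flat = mask.reshape(Height * Width, Height * Width)
    q_flat, k_flat, v_flat = (t.reshape(1, Height * Width, Dim) for t in (q, k, v))  # [1, Height*Width, Dim]
    out = F.scaled_dot_product_attention(q_flat, k_flat, v_flat, attn_mask=mask_flat)  # True = attend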

There are some obvious disadvantages of this compared to Natten. Specifically,

  1. This has quadratic computation cost wrt height * width.
  2. It must materialize the entire attention mask.

But, if we use some variant of sliding-window attention, we can mostly resolve 1, right? Is there something I'm missing about why this would be horrendously inefficient?
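For a sense of scale (illustrative sizes, not a benchmark), the arithmetic behind point 1: dense masked attention scores every key for every query, while 2-D NA with a `k x k` window scores only `k * k` keys per query. The helper name is hypothetical.

```python
def score_counts(H, W, k):
    """Number of attention scores per head: dense masked attention vs. 2-D NA."""
    n = H * W
    dense = n * n   # every query scores every key; most are then masked out
    na = n * k * k  # each query scores exactly k * k neighbors
    return dense, na


dense, na = score_counts(64, 64, 7)
print(dense, na, round(dense / na, 1))  # 16777216 200704 83.6
```

The ratio is simply `H * W / k**2`, so the gap grows linearly with the number of tokens.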

alihassanijr commented 5 months ago

> I'm still somewhat confused. Is it not possible to implement 2D Natten with a 1D FMHA + attention mask?

It is possible; I just very seriously doubt it would have any advantage over FNA (fused neighborhood attention).

> To clarify, this diagram is showing a 2d Natten right?

That diagram shows 1-D, 2-D, and 3-D. The idea is that the kernel tiling changes according to the spatial rank, but the threadblock-level MMA is agnostic to it. Load and store are mostly what's affected.
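A small sketch of what "tiling changes with spatial rank, the MMA doesn't" can look like, assuming a 64-row query tile (a hypothetical size; the real tile shapes are kernel configuration details not given in this thread). The same 64 rows can be a 1-D strip, a 2-D block, or a 3-D box; only the coordinate arithmetic used for loads and stores changes.

```python
from itertools import product


def tile_coords(tile_shape, tile_origin):
    """Token coordinates covered by one query tile (hypothetical tiling scheme).

    The tile is a box of shape `tile_shape` starting at `tile_origin`. Its tokens are
    enumerated row-major, so the MMA always sees prod(tile_shape) rows, whatever the rank.
    """
    return [tuple(o + d for o, d in zip(tile_origin, offsets))
            for offsets in product(*(range(s) for s in tile_shape))]


for shape in [(64,), (8, 8), (4, 4, 4)]:
    coords = tile_coords(shape, (0,) * len(shape))
    print(f"{len(shape)}-D tile {shape}: {len(coords)} rows, first two {coords[:2]}")
```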

> I think you should be able to implement it with something along the lines of (pseudocode)

So technically you can implement all forms of neighborhood attention with attention masks; there's no doubt there. NA is just a subgraph of the self attention graph.

> There are some obvious disadvantages of this compared to Natten. Specifically,
>
>   1. This has quadratic computation cost wrt height * width.
>   2. It must materialize the entire attention mask.
>
> But, if we use some variant of sliding-window attention, we can mostly resolve 1, right? Is there something I'm missing about why this would be horrendously inefficient?

As you pointed out, materializing the attention mask kind of defeats the purpose of reducing the memory footprint with a fused kernel. I'm also pretty sure you can't get rid of the quadratic computation, because it's not straightforward to modify a 1-D sliding window to replicate 2-D or 3-D neighborhoods. I actually haven't given this much thought recently, but my guess is that it won't be possible. Dilation might make it even more complicated. The sliding window aspect of NA makes these things difficult, and the behavior around the corners (where the window shifts instead of shrinking) just makes it worse.

In addition to the extra memory overhead, reading an explicit attention mask from global memory will probably slow down an attention kernel more than the overhead that 2D and 3D NA add because of the GEMM -> GETT changes.

So, to clarify: 2-D and 3-D are not as performant as 1-D in terms of bandwidth. That means if we flatten QKV across height and width and call FMHA/FNA1d with window size = kernel_size ** 2, it will probably be faster than sending the original QKV through FNA2d.

However, that doesn't mean FNA2d is so much slower that an explicit attention mask would be faster.

It might be for a few small edge cases, but in general I think the surest way to reduce global memory reads and FLOPs, and of course to ensure correctness, is to convert the problem into a GETT, as is done for convolutions.
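The flattened comparison mentioned above (FMHA/FNA1d with window size `kernel_size ** 2`) does the same amount of work per query as 2-D NA but over a different neighborhood, which is why it only serves as a rough speed reference. A pure-Python check, assuming odd `k` and the usual clamped-window rule (helper names are hypothetical):

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def na2d_keys(h, w, H, W, k):
    """Flattened key indices of 2-D NA with a k x k window."""
    return {y * W + x for y in na_window_1d(h, H, k) for x in na_window_1d(w, W, k)}


def na1d_flat_keys(h, w, H, W, k):
    """Flattened key indices of 1-D NA over the raster order with window k * k."""
    return set(na_window_1d(h * W + w, H * W, k * k))


H, W, k = 8, 8, 3
same_cost = all(len(na2d_keys(h, w, H, W, k)) == len(na1d_flat_keys(h, w, H, W, k)) == k * k
                for h in range(H) for w in range(W))
differs = na2d_keys(4, 4, H, W, k) != na1d_flat_keys(4, 4, H, W, k)
print(same_cost, differs)  # True True
```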

> why this would be horrendously inefficient?

The explicit attention mask would be my guess. Even if it's just [height * width, height * width] and repeated across batch and heads, and even if we found a way to get 1-D sliding windows to reduce the iterations over KV (and by extension the O(n^2) complexity), which I still don't think is possible, I'd still be skeptical that it would have any advantage over the FNA kernels. At the very least, reading from a tensor that can't be guaranteed to be aligned would probably make things difficult on Ampere and Hopper.
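Back-of-the-envelope numbers for the mask's memory cost (illustrative sizes; a fused kernel that computes the mask from indices reads none of these bytes). The helper is hypothetical:

```python
def mask_bytes(H, W, bytes_per_elem=1, batch=1, heads=1):
    """Global-memory size of an explicit [H*W, H*W] attention mask.

    bytes_per_elem = 1 for a bool mask, 2 for an fp16 additive mask, etc.
    batch/heads > 1 models a mask materialized per batch and head instead of broadcast.
    """
    n = H * W
    return batch * heads * n * n * bytes_per_elem


MiB = 2 ** 20
print(mask_bytes(64, 64) / MiB)                    # 16.0 (bool, shared)
print(mask_bytes(64, 64, bytes_per_elem=2) / MiB)  # 32.0 (fp16 additive)
```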

Chillee commented 5 months ago

> I actually haven't given this much thought recently, but my guess is that it won't be possible.

So, to be more explicit, I plotted what the attention mask would look like for a 2-D image.

[Image: plot of the flattened attention mask for 2-D neighborhood attention on an image]

From my understanding, the primary computation all lies within some distance from the "diagonal" region, right? The bulk of the sparsity lies from not computing regions outside that diagonal region, which I believe local/sliding window attention can represent.

I agree that there is some further sparsity within the diagonal region that local/sliding window attention could not take advantage of. I assume FNA2d takes advantage of that sparsity as well?
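A pure-Python check of the "diagonal region" picture, assuming raster (row-major) flattening and the usual clamped 2-D window (helper names are hypothetical). Every NA pair lies within `(k - 1) * (W + 1)` of the diagonal, so a 1-D sliding window of that radius covers the mask, but it spans far more keys per query than the `k * k` NA actually needs. Interior queries only need radius `(k // 2) * (W + 1)`; the extra width comes from border queries whose windows are shifted.

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def na2d_flat_pairs(H, W, k):
    """All (query, key) pairs of 2-D NA in raster-flattened indices."""
    for h in range(H):
        for w in range(W):
            for y in na_window_1d(h, H, k):
                for x in na_window_1d(w, W, k):
                    yield h * W + w, y * W + x


H, W, k = 16, 16, 5
max_offset = max(abs(qi - ki) for qi, ki in na2d_flat_pairs(H, W, k))
print(max_offset, (k - 1) * (W + 1))  # 68 68
print(2 * max_offset + 1, k * k)      # 137 25: keys per query in the band vs. in NA
```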

alihassanijr commented 5 months ago

> From my understanding, the primary computation all lies within some distance from the "diagonal" region, right? The bulk of the sparsity lies from not computing regions outside that diagonal region, which I believe local/sliding window attention can represent.

You are right; with a little modification to a 1-D attention kernel, we should be able to cut off most of what's outside the diagonal region. I'm not sure it would support dilation (dilation would require holding more than one stride value, which forces the GEMM -> GETT conversion, so we'd wind up with FNA2d again), but let's set that aside for now.

> I agree that there is some further sparsity within the diagonal region that local/sliding window attention could not take advantage of. I assume FNA2d takes advantage of that sparsity as well?

Yes (kind of). It's all still subject to your GEMM shape / tile size, but effectively yes. At the SM level, the kernel tries to load QKV in a way that maximizes spatial proximity, which minimizes the number of attention weights that get masked out.
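A simplified model of why spatial proximity matters, under stated assumptions: a 64-query tile on a 64 x 64 map with a 7 x 7 neighborhood, and each tile attending to every key it touches (real kernels also tile over KV, so actual numbers differ). The helper is hypothetical.

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def tile_stats(tile_h, tile_w, top, left, H, W, k):
    """Keys touched by a tile of queries, and the fraction of computed scores NA keeps."""
    queries = [(h, w) for h in range(top, top + tile_h) for w in range(left, left + tile_w)]
    keys = {(y, x) for h, w in queries
            for y in na_window_1d(h, H, k) for x in na_window_1d(w, W, k)}
    useful = len(queries) * k * k        # scores NA actually needs
    computed = len(queries) * len(keys)  # scores if the tile attends to every key it loads
    return len(keys), useful / computed


H, W, k = 64, 64, 7
print(tile_stats(1, 64, 32, 0, H, W, k))  # (448, 0.109375): 1 x 64 strip of queries
print(tile_stats(8, 8, 32, 32, H, W, k))  # (196, 0.25): 8 x 8 block of queries
```

Under this model the square block touches fewer than half as many keys as the strip, so a smaller share of its scores is masked out.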

Chillee commented 5 months ago

Thanks for the clarifications!

I understand why FNA2d is implemented the way it is, but I'd still be curious how its performance compares to local attention.

In particular, I think the "materializing attention mask" issue can be resolved if you compute the attention mask within the kernel itself (thus never materializing in gmem). So, a good baseline comparison might be comparing FNA2d vs. "full" local-attention. It is not a full 1:1 comparison, because FNA2d can take advantage of some additional sparsity while "full" local self-attention doesn't need to handle the additional index computations for the mask, but might be a good baseline to compare against.
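A sketch of computing the mask "within the kernel": a hypothetical predicate that decodes flattened query/key indices back into `(h, w)` and checks the clamped windows, verified against a materialized mask. In a real kernel, this per-element check is where the extra index computation and branching would come from.

```python
def na_window_1d(i, n, k):
    """Clamped 1-D neighborhood of query i (assumes odd k <= n)."""
    start = min(max(i - k // 2, 0), n - k)
    return range(start, start + k)


def na2d_allowed(q_idx, k_idx, H, W, k):
    """Mask bit for (query, key), computed on the fly from flattened indices."""
    qh, qw = divmod(q_idx, W)
    kh, kw = divmod(k_idx, W)
    return kh in na_window_1d(qh, H, k) and kw in na_window_1d(qw, W, k)


H, W, k = 6, 7, 3
materialized = [[False] * (H * W) for _ in range(H * W)]
for h in range(H):
    for w in range(W):
        for y in na_window_1d(h, H, k):
            for x in na_window_1d(w, W, k):
                materialized[h * W + w][y * W + x] = True

# The on-the-fly predicate reproduces the materialized mask exactly.
assert all(materialized[q][kv] == na2d_allowed(q, kv, H, W, k)
           for q in range(H * W) for kv in range(H * W))
```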

alihassanijr commented 5 months ago

That's fair; I'm curious about it myself. The only hitch is that if we want local attention to compute NA exactly, there are additional changes we'd have to make to the kernel. And I'm still unsure whether it would support varying parameters and dilation.

I guess I can take a look at this once we wrap up the backward kernel for FNA.

> In particular, I think the "materializing attention mask" issue can be resolved if you compute the attention mask within the kernel itself (thus never materializing in gmem). So, a good baseline comparison might be comparing FNA2d vs. "full" local-attention. It is not a full 1:1 comparison, because FNA2d can take advantage of some additional sparsity while "full" local self-attention doesn't need to handle the additional index computations for the mask, but might be a good baseline to compare against.

Could you clarify this? If we're modifying 1-D attention kernel to do 2-D NA with attention masking, it would still have to do extra indexing computation and checks (and the checks are more likely to contribute to latency than index computation because of branching). If we're not modifying the 1-D attention kernel, and comparing FNA2d to self attention with FMHA, I think we've already done that comparison in the paper.


By the way, I really appreciate all the feedback. God knows I'm not going to get it from conference reviews haha.

Chillee commented 5 months ago

> Could you clarify this? If we're modifying 1-D attention kernel to do 2-D NA with attention masking, it would still have to do extra indexing computation and checks (and the checks are more likely to contribute to latency than index computation because of branching).

I mean that an "easy"-ish comparison to do would be to run a local/sliding window attention kernel and compare against FNA2d. It won't exactly match the semantics of FNA2d, but it would give an upper-bound of how efficient a local/sliding-window attention kernel would be compared to FNA2d.

> By the way, I really appreciate all the feedback. God knows I'm not going to get it from conference reviews haha.

Haha, I think Natten is quite cool :) We've been thinking about how to handle more of these kinds of attention "variants" in PyTorch, and Natten is a pretty nice example.

alihassanijr commented 5 months ago

> I mean that an "easy"-ish comparison to do would be to run a local/sliding window attention kernel and compare against FNA2d. It won't exactly match the semantics of FNA2d, but it would give an upper-bound of how efficient a local/sliding-window attention kernel would be compared to FNA2d.

Okay, yeah, so I did try a few problem sizes before the release, but most of them used really large kernel sizes. I'll definitely come up with a set of problem sizes and test 2-D and 3-D against 1-D very soon.

> Haha, I think Natten is quite cool :) We've been thinking about how to handle more of these kinds of attention "variants" in PyTorch, and Natten is a pretty nice example.

Oh wow, thank you 🙂. That means a great deal coming from you. Our goal has always been to provide easy-to-use, autograd-compatible interfaces to all the kernels we've been developing, but there's definitely plenty of room to simplify.