SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Grouped Query Attention without repeat() #92

Open Birch-san opened 10 months ago

Birch-san commented 10 months ago

Grouped Query Attention improves parameter-efficiency of attention KV projections and reduces IO at inference-time, making inference faster.

It can be implemented naïvely by just repeating K and V along the head dim. I did so here:
https://github.com/Birch-san/natten-fwd-ad/blob/gqa/src/natten_block.py

_I am not sure whether a repeat() or repeat_interleave() is preferred, but probably "whatever Llama did" can be considered the standard practice._
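For reference, a minimal sketch of the naïve approach in PyTorch. It uses plain dense attention as a stand-in for the NATTEN op (NATTEN itself is not called), and the helper names `repeat_kv` and `dense_attention` are illustrative. `repeat_kv` follows the expand-then-reshape pattern from Meta's Llama reference code, which is equivalent to `repeat_interleave()` along the head dim: each KV head is copied to a contiguous block of query heads, and the `reshape` is what materializes the copy.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Llama-style KV repeat for a [batch, seq, kv_heads, dim] tensor.

    Equivalent to torch.repeat_interleave(x, n_rep, dim=2): KV head i is
    copied to query heads i*n_rep .. i*n_rep + n_rep - 1.
    """
    batch, seq, kv_heads, dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(batch, seq, kv_heads, n_rep, dim)
        .reshape(batch, seq, kv_heads * n_rep, dim)  # this step copies the data
    )

def dense_attention(q, k, v):
    """Plain softmax attention over [batch, seq, heads, dim] tensors
    (a stand-in for a neighborhood-attention op)."""
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

torch.manual_seed(0)
q_heads, kv_heads = 6, 3
q = torch.randn(1, 16, q_heads, 8)
k = torch.randn(1, 16, kv_heads, 8)
v = torch.randn(1, 16, kv_heads, 8)

n_rep = q_heads // kv_heads
out = dense_attention(q, repeat_kv(k, n_rep), repeat_kv(v, n_rep))
print(out.shape)  # torch.Size([1, 16, 6, 8])
```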

But perhaps instead of incurring the IO cost of repeat(), we could direct the matmul to visit each K/V head kv_groups times?

For example, if we unflattened Q's head dim into (groups, heads):

Q=(1, q_heads)

we can do the same with KV, then expand() its groups dim by the number of kv_groups:

K=V=(1, kv_heads)
K=V=(kv_groups, kv_heads)

I'm not sure whether that expand() is free, but if it is: then we would need the NATTEN API to accept arguments with these kinds of shapes:

```
# [batch, groups, heads, hei, wid, channels]
Q=[1, 1, 6, 128, 128, 64]
K=[1, 2, 3, 128, 128, 64]
V=[1, 2, 3, 128, 128, 64]
```

where so long as [groups, heads] flattens to the same amount (6), the arguments would be allowed.

maybe the user would need to tell the NATTEN API whether to access the groups via a repeat() or repeat_interleave() access pattern. or maybe only one of those access patterns makes sense to support perf-wise.
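Either access pattern reduces to a cheap per-head index calculation that a kernel could do instead of reading a repeated copy. A small sketch, assuming the number of query heads is a multiple of the number of KV heads; the helper `kv_head_for` is hypothetical, not a NATTEN API.

```python
import torch

def kv_head_for(q_head: int, q_heads: int, kv_heads: int, interleave: bool) -> int:
    """Hypothetical helper: which KV head a query head reads, with no repeat().

    interleave=True  mirrors torch.repeat_interleave (blocks of query heads share a KV head)
    interleave=False mirrors Tensor.repeat (query heads cycle through the KV heads)
    """
    group_size = q_heads // kv_heads
    return q_head // group_size if interleave else q_head % kv_heads

q_heads, kv_heads = 6, 3
interleaved = [kv_head_for(h, q_heads, kv_heads, True) for h in range(q_heads)]
cycled = [kv_head_for(h, q_heads, kv_heads, False) for h in range(q_heads)]
print(interleaved)  # [0, 0, 1, 1, 2, 2]
print(cycled)       # [0, 1, 2, 0, 1, 2]

# Cross-check against explicit repeats of a tensor whose "KV heads" hold their own index.
kv_ids = torch.arange(kv_heads)
assert torch.repeat_interleave(kv_ids, q_heads // kv_heads).tolist() == interleaved
assert kv_ids.repeat(q_heads // kv_heads).tolist() == cycled
```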

does it sound like a speedup is possible here compared to just naïve repeat?

_note: I don't think scaled_dot_product_attention supports this kind of thing (I tried it but it rejected the tensor shapes). so this isn't a parity item._

alihassanijr commented 10 months ago

Thanks for bringing this up; I agree that it makes sense to allow grouped heads. As far as I see it's all possible with views, as long as the ops support the extra rank.

However, it might be a little difficult to do right now, because we're in the process of changing the memory layout in NATTEN from [batch, heads, *, dim] to [batch, *, heads, dim]. It'll avoid the extra permutations around the op, and there's really no point in permuting heads unless we're strictly doing a BMM, which hasn't really been the case since the first CUDA kernel.
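To illustrate why the heads-last layout avoids permutations: in PyTorch, a linear projection produces `[batch, height, width, heads * dim]`, which splits into `[batch, height, width, heads, dim]` as a free view, whereas `[batch, heads, height, width, dim]` needs a permute and, for a kernel that requires contiguous input, a full copy. The sizes and the `to_q` projection are hypothetical.

```python
import torch

batch, height, width, heads, dim = 2, 8, 8, 4, 16
x = torch.randn(batch, height, width, heads * dim)
to_q = torch.nn.Linear(heads * dim, heads * dim)  # hypothetical projection

q = to_q(x)  # [batch, height, width, heads * dim]

# Heads-last ([batch, *, heads, dim]): splitting the channel dim is a free view.
q_last = q.unflatten(-1, (heads, dim))  # [batch, height, width, heads, dim]
print(q_last.is_contiguous(), q_last.data_ptr() == q.data_ptr())  # True True

# Heads-first ([batch, heads, *, dim]): needs a permute; a kernel that expects
# contiguous input then forces a full copy (and the output must be permuted back).
q_first = q_last.permute(0, 3, 1, 2, 4)  # [batch, heads, height, width, dim]
print(q_first.is_contiguous())  # False
q_first = q_first.contiguous()  # extra read + write of the whole tensor
print(q_first.data_ptr() == q.data_ptr())  # False
```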

Multi-query is the easiest, we'd just add the extra head dimension and set its stride to 0 indicating no movement. All of our kernels could support this easily if we relax the condition on inputs being contiguous.

Grouped query should also be easy if done similarly; we'd just add an extra dim of length 1 and stride 0.
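A PyTorch sketch of the stride-0 trick described above, in the `[batch, *, heads, dim]` layout with hypothetical sizes. `expand()` only creates a view with stride 0 along the broadcast dimension, so it does not allocate; what costs a copy is merging the expanded group dim back into a single heads dim, which is why the op itself would need to accept stride-0 heads or the extra rank.

```python
import torch

batch, height, width, dim = 1, 16, 16, 64
q_heads, kv_heads = 6, 3
group_size = q_heads // kv_heads

# Multi-query: a single KV head, broadcast across all query heads with stride 0.
k_mqa = torch.randn(batch, height, width, 1, dim)
k_mqa_view = k_mqa.expand(batch, height, width, q_heads, dim)
print(k_mqa_view.stride()[3])  # 0: every head reads the same memory
print(k_mqa_view.data_ptr() == k_mqa.data_ptr())  # True: no copy

# Grouped query: add a length-1 group dim after the KV heads, then expand it.
k_gqa = torch.randn(batch, height, width, kv_heads, dim)
k_gqa_view = k_gqa.unsqueeze(-2).expand(batch, height, width, kv_heads, group_size, dim)
print(k_gqa_view.stride()[4])  # 0: both heads in a group share one KV head

# Merging (kv_heads, group_size) into one heads dim cannot be expressed as a view,
# so reshape silently materializes the repeat.
k_flat = k_gqa_view.reshape(batch, height, width, q_heads, dim)
print(k_flat.data_ptr() == k_gqa.data_ptr())  # False: materialized copy
```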

So in summary, yes, I think it should be possible without explicitly repeating. It just might require me to think about the problem a little bit because it'll be a somewhat significant change in the API to either allow higher-rank inputs, or always expect them (because the grouped query attention format will contain both standard multi-headed and multi-query).

Do you happen to need this urgently, or can it wait for a while?

Birch-san commented 10 months ago

thanks for thinking so deeply about this 🙂 avoiding permutations sounds good too; that seems more important for now!

GQA can wait for a while. it's something that would be nice to put into an HDiT one day, especially into a bigger one. but haven't made any specific plans for how/when we would approach scaling up the arch.

MQA is known to be not as competitive on quality, so I think it wouldn't be suited to typical use-cases. could be useful for benchmarks, like competing on "fewest params" or "fastest forward pass" or something.

alihassanijr commented 10 months ago

Of course; and congratulations on the paper! It's a very interesting work.

We're planning on moving to the new layout soon in order to unify the API for our upcoming fused kernels, so stay tuned for that.

As for GQA and MQA, I'll see if I can add support for them while switching the memory layout, because the two are somewhat related.

Birch-san commented 8 months ago

glad to hear that the fused kernels' new memory layout may provide a path to supporting GQA.

and congrats on shipping FNA forward!
how's FNA backward coming along?

wondering about using NATTEN in the pretraining of an LLM.
specifically:

is a featureset like this close enough we could wait for it?

alihassanijr commented 8 months ago

Thank you!

I am working on FNA backward; it's a little difficult to estimate at this point, but I'm hoping it won't be more than a few weeks since I don't anticipate miscellaneous school work getting in the way for a while.

Naive kernels support all the new features introduced by FNA: causal masking and varying axis parameters (2-D and 3-D). (However, this comes with a minor performance regression: previously, a lot of those parameters were evaluated at compile time because of the stricter limitations.)

BMM-style ops (naive and GEMM) both support registers, and I think I can add GQA support to all ops (BMM-style and fused) pretty quickly if that's blocking you from experimenting. But training with FNA would still depend on when I can wrap up FNA backward.

However, thinking about GQA with and without the extra repeat, I'm not too sure whether you'd see a noticeable improvement in speed, if any. You definitely will save on the extra latency from the repeat op, but at the same time the threads executing attention kernels would have more conflicted memory accesses, which may or may not exceed the repeat latency. So I wouldn't really get my hopes up that getting rid of the repeat op would help much. But again this is all just guesswork at this point. We can implement and observe the effect for sure.

The only thing I'm unsure about is supporting registers or zero attention in FNA. Forward pass is mostly trivial in both cases, but backward pass would be a little tricky. I have a naive solution to supporting registers in my head, but I'm not 100% on whether it'll work or whether we can implement it quickly enough.

Birch-san commented 8 months ago

thanks for the quick reply!

it's good to hear your intuition that resorting to a repeat may not mean leaving much perf on the table.

how naïve are the naïve kernels? are they doing all the same IO+computation of full self-attention, just applying a complex mask? or is it better than that (avoids quadratic complexity)?

from the readme, it sounds like GEMM kernels don't support causal masking. if this is the case, then we'd not be able to use them for LLM training, even if GQA were added. however, the GQA would still be useful for training vision models.

in terms of whether to jump on this in order to unblock experimentation: GEMM GQA could certainly be useful for vision models (on which I'd hope to begin experimentation over the coming weeks, and probably scale-up is far enough away that FNA could become available). for language models though we're exploring options for a fused kernel solution.

alihassanijr commented 8 months ago

Of course; happy to help.

On naive kernels; the gist of it is, no, none of our kernels really do self attention + masking, but that doesn't mean they will be faster than self attention. It depends on the problem size, element type, and hardware (more on this below).

Yes unfortunately GEMM kernels don't support causal masking yet. But if you're particularly interested in 1-D Causal NA, the good news is that it is equivalent to 1-D Causal sliding window attention. As far as I know both FAv2 and xFormers' FMHA support 1-D sliding window attention with causal masking, so you probably can use them for now, but again only when your token space is 1-D, and only when you're doing causal masking. The major difference between neighborhood and sliding window attention is in how corner cases are handled; neighborhood attention attempts to force every token to attend to as many tokens as it can, bounded by neighborhood size. This attempt can only happen in non-causal scenarios, which is why causal NA and causal sliding window effectively have the same pattern.
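A small sketch of that corner-case difference in 1-D, assuming dilation 1 and an odd window size no larger than the sequence. It builds explicit boolean masks in PyTorch purely to visualize which keys each query attends to; NATTEN's kernels do not work by masking full self-attention. The helper names are hypothetical.

```python
import torch

def na_window(i: int, length: int, kernel_size: int, causal: bool) -> tuple[int, int]:
    """[start, end) of the keys query i attends to in 1-D neighborhood attention."""
    if causal:
        return max(0, i - kernel_size + 1), i + 1
    # Shift the window at the borders so it always holds kernel_size keys.
    start = min(max(i - kernel_size // 2, 0), length - kernel_size)
    return start, start + kernel_size

def sliding_window(i: int, length: int, kernel_size: int, causal: bool) -> tuple[int, int]:
    """[start, end) for sliding-window attention: the window is simply truncated."""
    if causal:
        return max(0, i - kernel_size + 1), i + 1
    half = kernel_size // 2
    return max(0, i - half), min(length, i + half + 1)

def mask(window_fn, length: int, kernel_size: int, causal: bool) -> torch.Tensor:
    m = torch.zeros(length, length, dtype=torch.bool)
    for i in range(length):
        start, end = window_fn(i, length, kernel_size, causal)
        m[i, start:end] = True
    return m

length, k = 8, 3
print(mask(na_window, length, k, False)[0].int().tolist())       # [1, 1, 1, 0, 0, 0, 0, 0]
print(mask(sliding_window, length, k, False)[0].int().tolist())  # [1, 1, 0, 0, 0, 0, 0, 0]
print(torch.equal(mask(na_window, length, k, True), mask(sliding_window, length, k, True)))  # True
```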

But I'll re-emphasize that to my knowledge, there's no alternative to NATTEN and NA when it comes to higher-rank spaces like 2-D and 3-D. Happy to get into exactly what the differences are and why higher rank is not as trivial as 1-D in terms of implementation.

As for GQA, I'll definitely try and add it to BMM-style this week to unblock your vision experiments. And I hope to be able to finish FNA backward pretty soon as well.

More on naive vs GEMM/FNA

So naive kernels are only called that to indicate that they are the least performance-optimized kernels. This means the code is mostly straightforward and readable, and there's been no effort to make it more "hardware friendly". Naive kernels are generally good solutions when you either need a new feature really fast, or when you need a reference implementation to evaluate the correctness of performance-optimized implementations.

When it comes to most code, computation is usually not the main bottleneck, but rather making full use of the memory hierarchy and other engineering concepts that make the most use of your hardware. To give an example, writing a matrix multiplication in C (or any high level language for that matter) as three nested for loops is a naive implementation, and won't make full use of even a standard non-parallel processor.
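To make the example concrete, here are two pure-Python matrix multiplies that do identical arithmetic: the naive three-loop version, and a tiled (blocked) version that reorders the loops so small blocks of the inputs are reused while they are still in fast memory. In Python the interpreter overhead dwarfs any cache effect, so this only illustrates the restructuring; optimized BLAS/GPU kernels combine tiling with vectorization, parallelism and special instructions. The tile size is arbitrary.

```python
def matmul_naive(a, b):
    """Three nested loops: the textbook (naive) matrix multiply."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for p in range(k):
                c[i][j] += a[i][p] * b[p][j]
    return c

def matmul_tiled(a, b, tile=2):
    """Same arithmetic, reordered into tiles so each block of a and b is
    reused while it would still be in cache (or shared memory on a GPU)."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for p in range(p0, min(p0 + tile, k)):
                        a_ip = a[i][p]  # reused across the inner j loop
                        for j in range(j0, min(j0 + tile, m)):
                            c[i][j] += a_ip * b[p][j]
    return c

a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(matmul_naive(a, b))                         # [[58.0, 64.0], [139.0, 154.0]]
print(matmul_tiled(a, b) == matmul_naive(a, b))   # True
```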

When you use a native operation in PyTorch or any other DL framework of your choice, they pick the best implementation available in relevant linear algebra libraries (BLAS, cuBLAS, and the like) or DL libraries (e.g. cuDNN), and very rarely may target naive implementations in PyTorch itself (if it's a really specific problem on very specific hardware).

Because of all of this, naive implementations of any concept that's considered computationally intensive will generally underutilize accelerators, particularly more modern ones. Aside from that, there are new hardware components once in a while that require very specific arrangements of data in their local memory / registers and very specific instructions to work, like Tensor Cores (since Volta) and the Tensor Memory Accelerator, or TMA (since Hopper).

So none of the kernels in NATTEN do full self attention (unless your window size is approximately the same as your input size). The difference is mostly in how well the code is performance optimized.
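A quick back-of-the-envelope count of query-key scores, assuming a hypothetical 128x128 feature map (the spatial size used earlier in this thread) and a 7x7 neighborhood. The two counts only coincide when the neighborhood spans the whole feature map.

```python
height, width = 128, 128
kernel = 7  # hypothetical 7x7 neighborhood

tokens = height * width                # 16,384 queries (and keys)
full_attn_scores = tokens * tokens     # every query scores every key
na_scores = tokens * kernel * kernel   # every query scores only its 7x7 neighborhood

print(full_attn_scores)                        # 268435456
print(na_scores)                               # 802816
print(round(full_attn_scores / na_scores, 1))  # 334.4
```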

GEMM kernels are much better alternatives than naive in many cases, but they don't do that much better in some cases with FP16/BF16, because the gather / scattering of attention weights breaks the asynchronous flow of the GEMM mainloop. But this is architecture specific. There are more details on this in the preprint.

Birch-san commented 8 months ago

ahh, gotcha. so naïve isn't a masked full-self-attention, but nor is it an optimized matmul, so it's not necessarily faster. thanks for explaining.

yes, FA2/xformers will probably do the trick. we could (gulp) try modifying one of those to support zero-attn. or "simulate" BOS/registers by, uh, inserting them into the sequence in multiple locations (without positional embedding).

do you have any expectation of how the eventual performance of fused NATTEN (1D, causal) could compare to FA2? I think if it ends up with similar performance, but with better support for registers, then that might be a reason to prefer it once ready.

alihassanijr commented 8 months ago

Yes that's correct.

So only speaking for 1-D, FNA and FMHA will have similar performance, because FNA is largely based on FMHA. On the other hand, FMHA has usually been inferior to FAv2 because there's an additional partitioning going on in FAv2 that FMHA doesn't do in the forward pass. This might not matter to many, but FAv2 only targets Ampere (and supports Ampere and newer architectures only), whereas FMHA supports all architectures since Maxwell (and even less important; FAv2 is only implemented for half precision). These were all part of the reason why we implemented FNA on top of FMHA.

You do bring up a fair point; registers with global attention are trivial with FAv2/FMHA. Registers with local attention might not be.

I guess in this case, if zero attention is something that will do the trick (meaning you'd only need to modify the softmax operator), making that change in either FAv2/FMHA would probably be trivial, as long as it doesn't make the backprop more complicated (can't really speak to that right now). I could definitely point you in the right direction for that, but probably better so in FMHA since that's the implementation I'm most familiar with. But if arbitrarily sized registers are what you need, then yeah it might take the same amount of work to add them to either of FNA or FMHA.
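A minimal PyTorch sketch of what zero attention means for the softmax, assuming the usual definition exp(x_i) / (1 + sum_j exp(x_j)). It is equivalent to appending one extra key whose logit is fixed at 0 and whose value vector is all zeros, which is why only the softmax (or the final normalization) needs to change. `softmax_one` is a hypothetical name; one head, no neighborhood restriction.

```python
import torch

def softmax_one(x: torch.Tensor) -> torch.Tensor:
    """Zero-attention softmax over the last dim: exp(x_i) / (1 + sum_j exp(x_j)).

    Computed as an ordinary (max-stabilized) softmax over x plus one extra logit
    fixed at 0, after which the extra column is dropped.
    """
    zero = torch.zeros_like(x[..., :1])
    return torch.cat([x, zero], dim=-1).softmax(dim=-1)[..., :-1]

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8) for _ in range(3))  # [tokens, dim], one head
scores = q @ k.T * 8 ** -0.5

# (a) modified softmax in front of the usual attn_probs @ V
out_a = softmax_one(scores) @ v

# (b) unmodified attention with one extra key whose logit is 0 and value is 0
scores_b = torch.cat([scores, torch.zeros(4, 1)], dim=-1)
v_b = torch.cat([v, torch.zeros(1, 8)], dim=0)
out_b = scores_b.softmax(dim=-1) @ v_b

print(torch.allclose(out_a, out_b, atol=1e-6))  # True
print(softmax_one(scores).sum(-1))  # each row sums to less than 1
```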

alihassanijr commented 8 months ago

One other idea (and I might be totally off on this; I need to read the online softmax paper again) is to see whether we can just completely split the query-register cross attention into a different branch, and aggregate the two resulting outputs together into the final one. It should be agnostic to the "core" attention and how it's implemented (fused or unfused) and it should be at least in theory exactly the same as the original.

Birch-san commented 8 months ago

half-precision and "Ampere and newer" are both fine for our use-case, so if FAv2 is faster then I guess we'd prefer that.

Tri Dao already hinted at where to change the forward pass to add support for zero attn:
https://github.com/Dao-AILab/flash-attention/issues/616#issuecomment-1775912492

Probably no change is needed for the backwards pass, if my maths is right. I walked through the softmax derivative, adding 1 into the denominator. I reached the same result in both cases (i=j and i≠j): a derivative defined in terms of the result from the forward pass.
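A quick numerical check of that derivation, assuming the same zero-attention definition. With y_i = exp(x_i) / (1 + sum_j exp(x_j)), the Jacobian is dy_i/dx_j = y_i (delta_ij - y_j), exactly the ordinary softmax form, so the backward pass can be written purely in terms of the forward output y. The sketch compares that formula against PyTorch autograd in float64; the function names are hypothetical and the forward is left unstabilized for brevity.

```python
import torch

def softmax_one(x: torch.Tensor) -> torch.Tensor:
    # Direct formula exp(x_i) / (1 + sum_j exp(x_j)); not max-stabilized, fine for small inputs.
    e = x.exp()
    return e / (1 + e.sum(dim=-1, keepdim=True))

def softmax_one_backward(y: torch.Tensor, grad_y: torch.Tensor) -> torch.Tensor:
    # Vector-Jacobian product using dy_i/dx_j = y_i * (delta_ij - y_j):
    # grad_x_j = y_j * (grad_y_j - sum_i grad_y_i * y_i), i.e. only the forward output is needed.
    return y * (grad_y - (grad_y * y).sum(dim=-1, keepdim=True))

torch.manual_seed(0)
x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
grad_y = torch.randn(3, 5, dtype=torch.float64)

y = softmax_one(x)
(grad_autograd,) = torch.autograd.grad(y, x, grad_outputs=grad_y)
grad_analytic = softmax_one_backward(y.detach(), grad_y)
print(torch.allclose(grad_autograd, grad_analytic))  # True
```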

I think you're right that you could take advantage of online softmax to handle query-register xattn out-of-band. I implemented an online softmax as part of my memory-efficient attention in pure PyTorch, and yeah it looks like that's "just another chunk" to accumulate.

equally, I wonder how hard it'd be to add registers to FAv2. I figure I'd just need to find where it does online softmax over QK chunks, then incorporate additional QK chunks from a register tensor. and something similar for attn_probs @ V. and likewise for the backwards pass. but maybe this is a lot harder said than done in CUDA.
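A minimal PyTorch sketch of the online-softmax loop, assuming one head and omitting the neighborhood restriction (here every query sees every local chunk). Register tokens are treated as just one more (K, V) chunk; in a local-attention kernel the difference would be that every query visits the register chunk regardless of its window. The function and variable names are hypothetical and do not correspond to FAv2 internals.

```python
import torch

def online_softmax_attention(q, kv_chunks, scale):
    """Single-pass attention over a list of (K, V) chunks, FlashAttention-style.

    Keeps, per query row, a running max m, running denominator l and an
    unnormalized output accumulator acc; each new chunk rescales the old state.
    """
    m = torch.full((q.shape[0], 1), float("-inf"))
    l = torch.zeros(q.shape[0], 1)
    acc = torch.zeros(q.shape[0], kv_chunks[0][1].shape[-1])
    for k, v in kv_chunks:
        s = q @ k.T * scale                                  # [queries, chunk_len]
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        p = torch.exp(s - m_new)
        correction = torch.exp(m - m_new)                    # rescale previous partial sums
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v
        m = m_new
    return acc / l

torch.manual_seed(0)
dim = 16
scale = dim ** -0.5
q = torch.randn(10, dim)
k, v = torch.randn(32, dim), torch.randn(32, dim)          # "local" keys/values
k_reg, v_reg = torch.randn(4, dim), torch.randn(4, dim)    # hypothetical register tokens

chunks = [(k[:16], v[:16]), (k[16:], v[16:]), (k_reg, v_reg)]  # registers = one more chunk
out = online_softmax_attention(q, chunks, scale)

reference = torch.softmax(q @ torch.cat([k, k_reg]).T * scale, dim=-1) @ torch.cat([v, v_reg])
print(torch.allclose(out, reference, atol=1e-5))  # True
```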

alihassanijr commented 8 months ago

> half-precision and "Ampere and newer" are both fine for our use-case, so if FAv2 is faster then I guess we'd prefer that.

> Tri Dao already hinted at where to change the forward pass to add support for zero attn: https://github.com/Dao-AILab/flash-attention/issues/616#issuecomment-1775912492

> Probably no change is needed for the backwards pass, if my maths is right. I walked through the softmax derivative, adding 1 into the denominator. I reached the same result in both cases (i=j and i≠j): a derivative defined in terms of the result from the forward pass.

I can't really comment on the backward pass at the moment; let me take a closer look and get back to you.

> I think you're right that you could take advantage of online softmax to handle query-register xattn out-of-band. I implemented an online softmax as part of my memory-efficient attention in pure PyTorch, and yeah it looks like that's "just another chunk" to accumulate.

Yeah I think you should be able to use online softmax, but of course you'd need the attention kernel (or softmax if unfused) to store the row maximums and sums to global memory. Once you have those, you should be able to both implement the query-register xattn in torch or as a standalone kernel (triton might be a good choice, but if it's always a single register token, even a naive kernel might be okay.) A better approach is to modify the attention kernel to read row max and sum values as well, that way you can just call the attention kernel twice: once with the original KV, and once more with the register KV. (FAv2 might already support this -- but I'm not sure.)
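The two-launch variant can be sketched in PyTorch as follows, assuming one head and two disjoint key sets (core keys and register keys). Each launch returns its normalized output plus the per-row max m and sum l it would otherwise discard; the merge rescales both to a common max. Names are hypothetical, and the dense `attention_with_stats` stands in for whatever core attention kernel is used.

```python
import torch

def attention_with_stats(q, k, v, scale):
    """Attention that also returns the per-row max m and sum l, the statistics a
    kernel would need to write to global memory for a later merge."""
    s = q @ k.T * scale
    m = s.amax(dim=-1, keepdim=True)
    p = torch.exp(s - m)
    l = p.sum(dim=-1, keepdim=True)
    return (p @ v) / l, m, l

def merge(out_a, m_a, l_a, out_b, m_b, l_b):
    """Combine two normalized partial outputs computed over disjoint key sets."""
    m = torch.maximum(m_a, m_b)
    w_a = l_a * torch.exp(m_a - m)  # each branch's softmax mass on a common scale
    w_b = l_b * torch.exp(m_b - m)
    return (out_a * w_a + out_b * w_b) / (w_a + w_b)

torch.manual_seed(0)
dim = 16
scale = dim ** -0.5
q = torch.randn(10, dim)
k, v = torch.randn(32, dim), torch.randn(32, dim)        # core attention keys/values
k_reg, v_reg = torch.randn(1, dim), torch.randn(1, dim)  # a single register token

out = merge(*attention_with_stats(q, k, v, scale), *attention_with_stats(q, k_reg, v_reg, scale))
joint = torch.softmax(q @ torch.cat([k, k_reg]).T * scale, dim=-1) @ torch.cat([v, v_reg])
print(torch.allclose(out, joint, atol=1e-5))  # True
```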

> equally, I wonder how hard it'd be to add registers to FAv2. I figure I'd just need to find where it does online softmax over QK chunks, then incorporate additional QK chunks from a register tensor. and something similar for attn_probs @ V. and likewise for the backwards pass. but maybe this is a lot harder said than done in CUDA.

Well that's where it gets a little complicated. I would guess having the kernel read m_i and l_i the first time, and store the values to global memory is probably not a worse solution than having everything done within the same kernel launch, and it probably requires minimal effort to implement. Adding an extra KV that ignores the sliding window parameters will probably just require more work to implement, but the two kernel launches will be sequential, so you might feel more latency that way. But if you're using sliding window, I'm assuming you're dealing with a relatively large problem size, which means one extra kernel launch with just one tile to work with (assuming very few register tokens) will be negligible.

As for modification, if you're not familiar with the FAv2 codebase and CUTLASS, I would imagine the former would be a lot easier and quicker to implement than the latter.

alihassanijr commented 7 months ago

@Birch-san I looked into GQA/MQA further, and I think explicitly repeating the tensor might be the best solution for training. The backward pass can be handled in two different ways: 1. keep the current behavior, and reduce the gradient tensors for K and V, or 2. modify the kernel to output the same tensor layout as the inputs.

The first will obviously incur the extra latency from the additional reduce op, and the second will hit a race condition on "shared heads", which means we'd either have to write a completely different kernel that handles this case differently (e.g. change the threadblock layout to avoid the race condition, or use scratch memory with atomic locks), OR just simply avoid the issue by using atomic accumulation, which will be slower and non-deterministic.

It's near impossible to tell which will be the best performing one, but if I had to make a guess, I wouldn't expect the best possible solution from 2 to be that much faster than 1.
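For intuition on option 1: when K and V are repeated explicitly, PyTorch autograd already performs that reduction, summing the per-query-head gradients over each group back onto the shared KV head, deterministically. A sketch with hypothetical sizes, using a random tensor in place of the gradient an attention backward kernel would return:

```python
import torch

torch.manual_seed(0)
batch, seq, kv_heads, group_size, dim = 2, 5, 3, 2, 4
k = torch.randn(batch, seq, kv_heads, dim, requires_grad=True)

# Forward: explicit repeat, as in the naive GQA approach.
k_rep = torch.repeat_interleave(k, group_size, dim=2)  # [batch, seq, kv_heads * group_size, dim]

# Stand-in for dL/dK from an unmodified attention backward kernel,
# which produces one gradient per *query* head.
grad_k_rep = torch.randn_like(k_rep)
(grad_k,) = torch.autograd.grad(k_rep, k, grad_outputs=grad_k_rep)

# The same reduction written out: sum the gradients of each group of heads.
manual = grad_k_rep.unflatten(2, (kv_heads, group_size)).sum(dim=3)
print(torch.allclose(grad_k, manual))  # True
```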

If you end up using torch.compile, the graph optimizations might already reduce some of the additional latency from repeat and the backward pass for that, but that's hard to say (and I'll try my best to fix NATTEN's compatibility with torch.compile so it doesn't force torch to break graphs.)

That said, I still think your case for inference is valid and relatively easy to handle, so I'm happy to try and add support for that to all NATTEN ops at some point.

What do you think?

Birch-san commented 7 months ago

hmm it sounds like the "custom kernel" solutions will add complexity and require a fair bit of engineering. without even being clear whether it's worth it.

I dunno how expensive the repeat is, nor what its price is after a torch compile. but I don't have any data to say it's a bottleneck that's worth putting a large amount of engineering into.

I think better to spend the time on sure bets like continuing to fuse the backward pass.

I think GQA is mostly intended to speed up inference anyway, so if you're saying there's an easy path to make it fast/supported for inference only: that sounds worth a shot.

I think there's still plenty of weeks before I'll be ready to use this for anything, so no need to hurry on my account.

alihassanijr commented 7 months ago

Yeah I think forward pass will be easily manageable, and I'm pretty sure I've already seen it handled in some other inference engine. And it makes sense: the only conflict will be on reads and not writes, which can only be very slightly slower than a case where K/V have already been repeated, but my guess is it'll almost always be faster than doing the repeat explicitly. So yeah I'll try and have NATTEN support it pretty soon.
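A forward-only sketch of GQA without materializing the repeat, in PyTorch with dense attention standing in for neighborhood attention (names and sizes hypothetical). Query heads are unflattened into (kv_heads, group_size), following the grouping idea from the original post, so each K/V head is contracted against its whole group of query heads in one batched product and the code never builds a repeated K/V tensor. It is checked against the explicit `repeat_interleave()` version.

```python
import torch

def gqa_attention_grouped(q, k, v):
    """Forward-only GQA without building repeated K/V.

    q: [batch, seq, q_heads, dim]; k, v: [batch, seq, kv_heads, dim].
    Every query head in a group reads the same K/V head (read-only sharing).
    """
    _, _, q_heads, dim = q.shape
    kv_heads = k.shape[2]
    qg = q.unflatten(2, (kv_heads, q_heads // kv_heads))  # [b, s, h_kv, g, d]
    scores = torch.einsum("bqhgd,bkhd->bhgqk", qg, k) * dim ** -0.5
    out = torch.einsum("bhgqk,bkhd->bqhgd", scores.softmax(dim=-1), v)
    return out.flatten(2, 3)                              # [b, s, q_heads, d]

torch.manual_seed(0)
q = torch.randn(1, 16, 6, 8)
k, v = torch.randn(1, 16, 3, 8), torch.randn(1, 16, 3, 8)

# Reference: explicit repeat_interleave, then ordinary multi-head attention.
ref_k = torch.repeat_interleave(k, 2, dim=2)
ref_v = torch.repeat_interleave(v, 2, dim=2)
probs = torch.softmax(torch.einsum("bqhd,bkhd->bhqk", q, ref_k) * 8 ** -0.5, dim=-1)
ref = torch.einsum("bhqk,bkhd->bqhd", probs, ref_v)
print(torch.allclose(gqa_attention_grouped(q, k, v), ref, atol=1e-5))  # True
```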

But yeah as far as I can tell, there's no easy solution to training without repeat that's guaranteed to be deterministic and faster. Or even just faster with tolerance for non-determinism (which probably isn't too uncommon for backprop.)