Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Add QuietAttention #616

Open · janEbert opened this issue 1 year ago

janEbert commented 1 year ago

The blog post titled Attention Is Off By One proposes adding 1 to the denominator of the softmax operation in attention in order to fix quantization instabilities, by allowing an attention head to effectively select no key for a given query. The post calls the resulting softmax function "softmax1" or "ghost softmax", and the attention function implementing softmax1 "QuietAttention".

The implementation should be super simple; add a boolean to the function signature and add 1 to the softmax accumulator before scaling.
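
For reference, here is a minimal (non-fused) PyTorch sketch of what QuietAttention would compute, which could serve as a ground truth when testing a kernel change. The function names and tensor layout are illustrative assumptions, not anything from this repo:

```python
import torch


def softmax1(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # "Ghost softmax": exp(x_i) / (1 + sum_j exp(x_j)),
    # written naively, directly following the blog post's definition.
    exp_x = torch.exp(x)
    return exp_x / (1.0 + exp_x.sum(dim=dim, keepdim=True))


def quiet_attention_ref(q, k, v, softmax_scale=None):
    # q, k, v: (batch, nheads, seqlen, headdim); non-causal, no dropout.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    return torch.einsum("bhqk,bhkd->bhqd", softmax1(scores), v)
```

A fused kernel would of course fold the extra 1 into its max-shifted running denominator rather than exponentiating raw scores like this.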

Would it be possible to add this feature? If it is more work than I am imagining, I can also invest my time; however, if at least the correct places for the increment could be identified, that would help a lot.

tridao commented 1 year ago

What about just adding a key that's all zero before calling attention?
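
For concreteness, a sketch of that workaround (attn_fn stands in for whatever attention call is used; the layout (batch, seqlen, nheads, headdim) is assumed). Since q · 0 = 0, the extra key contributes exp(0) = 1 to every softmax denominator, and an all-zero value keeps it out of the output, which is exactly softmax1:

```python
import torch.nn.functional as F


def attention_with_zero_kv(q, k, v, attn_fn):
    # q, k, v: (batch, seqlen, nheads, headdim).
    # Prepend one all-zero key and value along the sequence dimension;
    # the pad spec covers the last three dims: (headdim, nheads, seqlen) pairs.
    k = F.pad(k, (0, 0, 0, 0, 1, 0))
    v = F.pad(v, (0, 0, 0, 0, 1, 0))
    return attn_fn(q, k, v)
```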

janEbert commented 1 year ago

Definitely; but then I'd (1) destroy that sweet bit alignment, and (2) introduce an extra copy. Wouldn't that be quite detrimental to performance?

janEbert commented 1 year ago

I'd like to apply this to a large-scale training, so keeping optimal performance is important to me.

tridao commented 1 year ago

> I'd like to apply this to a large-scale training, so keeping optimal performance is important to me.

I'm curious, what kind of improvement do you see at smaller scale?

iiLaurens commented 1 year ago

> I'd like to apply this to a large-scale training, so keeping optimal performance is important to me.
>
> I'm curious, what kind of improvement do you see at smaller scale?

I can't answer that exactly, but empirical evidence is starting to appear suggesting that LLMs turn the first token into an attention sink that has no semantic meaning whatsoever beyond the first layer. It has effectively become the "zero key" you suggest. See also the StreamingLLM paper/project. Quiet Attention could replace the need for an attention sink.

The StreamingLLM approach is basically an ad-hoc hack to improve perplexity beyond the context limit by keeping the first token around. I suspect that an LLM trained with QuietAttention might not need this hack anymore and could scale beyond the trained context limit out of the box.

Because it is only a hypothesis, I don't think it warrants a change to this kernel yet, but a temporary branch or fork so that the GPU-poor can start experimenting more efficiently might be nice!

janEbert commented 1 year ago

Thank you @iiLaurens, this summarizes the blog post quite well! In addition, I'd like to leave here an empirical evaluation of the suggestion at a small scale: Blog post and the corresponding Colab.

To be clear, QuietAttention in and of itself is not about improving model performance after pre-training. Its purpose is to avoid the creation of large values in this attention sink, which make quantization difficult because of the resulting outliers. So the benefits are mostly expected to show up in the quantized model, by making quantization more exact in the non-outlier region. There could be benefits from avoiding large post-softmax values during training as well, but this is not the reason QuietAttention was proposed.

@tridao, if you'd be so kind to point me to the place where to add the extra 1, I'd gladly research in a fork like @iiLaurens suggested and create a PR later (if you'd like the change to be in here). I'm really just worried about not inserting the addition in all the necessary places in the kernel(s?). :)

tridao commented 1 year ago

This multiplies the output by the inverse of the denominator, so you can add something to the denominator there. I think the output was already multiplied by exp(-max) before this normalization, so the denominator being applied is sum_j exp(x_j - max). So I think we want to add exp(-max) here (and not 1). The max variable stores the max before multiplying by 1/sqrt(d), so I'm guessing the term is exp(-max * params.softmax_scaling). I'm not positive because I haven't tried.
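
Spelled out in generic notation (a sketch; this doesn't use the actual kernel variable names):

$$
\mathrm{softmax}_1(x)_i \;=\; \frac{e^{x_i}}{1 + \sum_j e^{x_j}} \;=\; \frac{e^{x_i - m}}{e^{-m} + \sum_j e^{x_j - m}}, \qquad m = \max_j x_j,
$$

so with a max-shifted denominator the constant to add is $e^{-m}$ rather than 1, and if the stored max is taken over the unscaled scores, the $m$ above is that stored max times the softmax scaling factor.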

In the backward pass, I think there's some step that assumes that things sum to 1, which will no longer be the case if we have softmax1. I'll need to think about it more.

janEbert commented 1 year ago

Thank you so much, Tri! This is unbelievably helpful already. :) When I was looking through the code, I expected the denominator to be in one of the LSE variables, so you already kept me from looking at the wrong places. It's also great to know that there's only one place (or two, if we include the backward pass) to modify.

I'll fool around with the implementation and get back to you!

ad8e commented 7 months ago

Adding a zero key at the beginning works, and the speed penalty is <2%. However, it doesn't work for sliding window attention, because the key falls out of the window.

Birch-san commented 7 months ago

I think the backward pass may not require any changes. I followed these steps to compute the softmax derivative, with 1 + ∑ substituted into the softmax denominator, and still ended up with the same result (for both i = j and i ≠ j): the softmax derivative is defined in terms of its own output.

Though if there are implementation details relying on the result summing to 1, I haven't considered those.
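
Concretely, a sketch of that calculation, writing $s_i$ for the softmax1 output:

$$
s_i = \frac{e^{x_i}}{1 + \sum_k e^{x_k}}, \qquad
\frac{\partial s_i}{\partial x_j}
= \frac{\delta_{ij}\, e^{x_i}\left(1 + \sum_k e^{x_k}\right) - e^{x_i} e^{x_j}}{\left(1 + \sum_k e^{x_k}\right)^2}
= s_i\left(\delta_{ij} - s_j\right),
$$

which is the same form as for the standard softmax, even though the $s_i$ no longer sum to 1.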

janEbert commented 7 months ago

Hey, thanks @ad8e for the feedback. In my mind, that is too large a cost for something that could be as simple as adding a 1. But of course it all depends on the use case; for me, even <2% extra compute adds up to a lot over the course of pre-training.

@Birch-san thank you for sharing. I haven't been able to modify the code so that it passes QuietAttention tests myself, even after trying out various forms for the added 1 in the denominator, such as expressing it in log2 space (which seemed to be the required form from my own inference) or in natural log, etc.