Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

How to obtain differentiable softmax_lse #1137

Open albert-cwkuo opened 1 month ago

albert-cwkuo commented 1 month ago

Hi @tridao ,

I'd like to use the value of softmax_lse in my model and back-propagate gradients through it. However, I have seen another discussion saying that it is not taken into account during the backward pass.

Does a newer version support backprop for softmax_lse? If not, how easy or difficult would it be to modify the CUDA code to support that? Thanks in advance for any advice!

tridao commented 1 month ago

Backprop on softmax_lse is not supported. Feel free to work on it if you need it. You just have to work out the gradient and then implement it. I suspect it's not too bad.

albert-cwkuo commented 1 month ago

Thanks for your reply :). I am new to CUDA, fused kernels, etc. Would you mind pointing me to which part I should dig into and giving me an abstract idea of how this could be implemented? Thanks a lot.

tridao commented 1 month ago

It depends on what the gradient looks like. What's the gradient for softmax_lse?

albert-cwkuo commented 1 month ago

Sorry I don't quite get what you mean :sweat_smile:.

What I expect is that, given the softmax_lse computed from q and k (whose attention scores have shape, say, N x nhead x Sq x Sk), I can obtain a gradient of the same shape as softmax_lse for each of its elements, and that gradient can then be further propagated through q and k. Only the first-order gradient is needed in my case.
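
To make the shapes concrete, here is a plain-PyTorch sketch of what I mean by softmax_lse, assuming the usual convention that it is the row-wise log-sum-exp of the scaled scores, one value per query position (so shape N x nhead x Sq). The helper name and the softmax_scale handling are my own, not FlashAttention's API:

```python
import torch

def softmax_lse_reference(q, k, softmax_scale=None):
    """Reference (non-fused) softmax_lse.

    q: (N, nhead, Sq, d), k: (N, nhead, Sk, d) -> (N, nhead, Sq)
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    # Scaled attention scores, shape (N, nhead, Sq, Sk)
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    # Log-sum-exp over the key axis: one value per query row
    return torch.logsumexp(scores, dim=-1)
```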

tridao commented 1 month ago

Well, you need to work out how to compute the gradient mathematically (e.g., see Appendix B.2 of the FlashAttention paper) before implementing it.

albert-cwkuo commented 1 month ago

Thanks a lot @tridao for the reference! Took me some time to derive the gradient of LSE w.r.t. q & k analytically:

For dq: $\frac{\partial}{\partial q}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})k$

For dk: $\frac{\partial}{\partial k}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})^T q$
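
As a sanity check, here is a small plain-PyTorch sketch comparing these formulas against autograd, assuming a unit upstream gradient on every LSE entry and omitting the 1/sqrt(d) scaling for simplicity:

```python
import torch

torch.manual_seed(0)
Sq, Sk, d = 4, 6, 8
q = torch.randn(Sq, d, requires_grad=True)
k = torch.randn(Sk, d, requires_grad=True)

# Autograd gradients of sum_i LSE_i with respect to q and k
lse = torch.logsumexp(q @ k.T, dim=1)   # (Sq,)
lse.sum().backward()

# Analytical gradients from the formulas above
p = torch.softmax(q @ k.T, dim=1)       # (Sq, Sk)
dq_analytic = p @ k                     # dLSE/dq = softmax(qk^T) k
dk_analytic = p.T @ q                   # dLSE/dk = softmax(qk^T)^T q

print(torch.allclose(q.grad, dq_analytic, atol=1e-5))  # True
print(torch.allclose(k.grad, dk_analytic, atol=1e-5))  # True
```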

Any advice on how and where I should plug this into the CUDA code so that the returned softmax_lse supports backprop?

tridao commented 1 month ago

Bwd pass code is here: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h
You can follow acc_dk and acc_dq to see how they are being computed right now.
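
For reference before touching the kernel, here is a hedged PyTorch-level sketch of the extra dq/dk terms a gradient on softmax_lse would contribute, given an upstream gradient dlse of shape (N, nhead, Sq). The shapes, the softmax_scale handling, and the function name are assumptions rather than FlashAttention's actual code path:

```python
import torch

def lse_backward_reference(q, k, dlse, softmax_scale=None):
    """Extra dq/dk contributions from a gradient on softmax_lse.

    q:    (N, nhead, Sq, d)
    k:    (N, nhead, Sk, d)
    dlse: (N, nhead, Sq)   upstream gradient w.r.t. softmax_lse
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    # Row-wise softmax of the scaled scores, shape (N, nhead, Sq, Sk)
    p = torch.softmax(torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale, dim=-1)
    # Weight each query row's softmax by that row's upstream LSE gradient
    w = dlse[..., None] * p
    dq = softmax_scale * torch.einsum("bhqk,bhkd->bhqd", w, k)
    dk = softmax_scale * torch.einsum("bhqk,bhqd->bhkd", w, q)
    return dq, dk
```

These terms would be added on top of the dq/dk that the attention backward already produces; inside the fused kernel that presumably means folding a dlse-weighted term into the acc_dq and acc_dk accumulations, though exactly where is something to work out against the kernel code linked above.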