Open albert-cwkuo opened 1 month ago
Backprop on softmax_lse is not supported. Feel free to work on it if you need it. You just have to work out the gradient and then implement it. I suspect it's not too bad.
Thanks for your reply :). I am new to CUDA, fused kernels, etc. Would you mind pointing me to which part I should dig into, and giving me an abstract idea of how this could be implemented? Thanks a lot.
It depends on what the gradient looks like. What's the gradient for softmax_lse?
Sorry I don't quite get what you mean :sweat_smile:.
What I expect is that, given the softmax_lse computed from q and k (where the score matrix $q k^T$ has shape, let's say, N x nhead x Sq x Sk), I can receive a gradient of the same shape as softmax_lse for each of its elements. Afterward, that gradient can be propagated further back through q and k. Only the first-order gradient is needed in my case.
Well, you need to work out how to compute the gradient mathematically (e.g. see the FlashAttention paper, Appendix B.2) before implementing it.
Thanks a lot @tridao for the reference! Took me some time to derive the gradient of LSE w.r.t. q & k analytically:
For dq: $\frac{\partial}{\partial q}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})k$
For dk: $\frac{\partial}{\partial k}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})^T q$
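A quick numerical sanity check of these formulas in plain PyTorch (nothing flash-attn specific; shapes are arbitrary, and the upstream gradient is all ones since I sum the LSE into a scalar):

```python
import torch

torch.manual_seed(0)
Sq, Sk, d = 5, 7, 4
q = torch.randn(Sq, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(Sk, d, dtype=torch.float64, requires_grad=True)

# Reference: LSE over the key axis, summed to a scalar so the
# upstream gradient is all ones (matching the formulas above).
lse = torch.logsumexp(q @ k.T, dim=1)
lse.sum().backward()

# Analytic gradients from the formulas above.
with torch.no_grad():
    p = torch.softmax(q @ k.T, dim=1)  # softmax(q k^T, axis=1)
    dq = p @ k                         # softmax(q k^T) k
    dk = p.T @ q                       # softmax(q k^T)^T q

assert torch.allclose(q.grad, dq)
assert torch.allclose(k.grad, dk)
print("analytic gradients match autograd")
```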
Any advice on how and where I should plug this into the CUDA code so that the returned softmax_lse supports backprop?
The bwd pass code is here: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h. You can follow acc_dk and acc_dq to see how they're currently computed.
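If all you need is the gradient itself (and you can afford materializing the Sq x Sk score matrix), one way to prototype it before touching the kernel is a pure-PyTorch custom autograd.Function using the formulas above. A rough sketch, not part of flash-attn (the class and variable names here are placeholders):

```python
import torch

class LSEFromQK(torch.autograd.Function):
    """Computes LSE(scale * q k^T) over the key axis and supports backprop.

    Naive O(Sq*Sk)-memory stand-in for what a fused kernel would do;
    it recomputes the softmax in the backward pass from the saved lse.
    """

    @staticmethod
    def forward(ctx, q, k, scale):
        s = (q @ k.transpose(-2, -1)) * scale
        lse = torch.logsumexp(s, dim=-1)
        ctx.save_for_backward(q, k, lse)
        ctx.scale = scale
        return lse

    @staticmethod
    def backward(ctx, dlse):
        q, k, lse = ctx.saved_tensors
        s = (q @ k.transpose(-2, -1)) * ctx.scale
        # Softmax recovered from the saved lse: p_ij = exp(s_ij - lse_i).
        p = torch.exp(s - lse.unsqueeze(-1))
        # Chain rule: weight each softmax row by the upstream gradient.
        w = dlse.unsqueeze(-1) * p
        dq = ctx.scale * (w @ k)
        dk = ctx.scale * (w.transpose(-2, -1) @ q)
        return dq, dk, None

# Usage: lse = LSEFromQK.apply(q, k, q.shape[-1] ** -0.5)
```

A real fused implementation would instead fold these extra terms into the existing acc_dq / acc_dk accumulation in flash_bwd_kernel.h, without ever materializing p.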
Hi @tridao ,
I'd like to use the value of softmax_lse in my model and backpropagate gradients through it. However, I've seen another discussion saying that it is not taken into account during the backward pass.
Does a newer version support backprop for softmax_lse? If not, how easy or difficult would it be to modify the CUDA code to support it? Thanks in advance for any advice!