Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Backprop through LSE #889

Open abf149 opened 5 months ago

abf149 commented 5 months ago

Hello, I would like to submit a PR for a new feature that would allow FlashAttention to support backpropagation through the log-sum-exp (LSE).

In other words, the mha_bwd function signature (in pseudocode) is currently

mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)

but would become

mha_bwd(dout, dsoftmax_lse, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)

Motivation: I am training an LM and want a penalty term against large LSE values in my loss function, to keep the attention scores from getting too large. So my loss function (pseudocode) is something like

cross_entropy(LM output) + lambda*sum(LSE over all attention layers)

where cross_entropy(LM output) is the usual LM loss, lambda is a tuning parameter, and the second term penalizes the LSE of every attention layer.
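
For concreteness, here is a rough PyTorch sketch of that loss, assuming each attention layer can expose its softmax_lse; the shapes and the random logits/targets/LSE tensors below are just stand-ins for illustration:

```python
import torch
import torch.nn.functional as F

# Dummy shapes / tensors purely for illustration.
batch, seqlen, vocab, heads, n_layers = 2, 16, 100, 4, 3
logits = torch.randn(batch, seqlen, vocab, requires_grad=True)      # LM head output
targets = torch.randint(0, vocab, (batch, seqlen))
lse_per_layer = [torch.randn(batch, heads, seqlen, requires_grad=True)
                 for _ in range(n_layers)]                           # softmax_lse of each layer

lam = 1e-3                                                           # the "lambda" above
ce = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))       # usual LM loss
lse_penalty = sum(lse.sum() for lse in lse_per_layer)                # sum(LSE over all layers)
loss = ce + lam * lse_penalty
loss.backward()  # every lse element now carries a gradient of lam; that is the dsoftmax_lse
                 # which has to reach the backward kernel of its layer
```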

The first term introduces a dout gradient that must be backpropagated through each attention layer. The second term introduces a dsoftmax_lse gradient that must be backpropagated through each attention layer as well.

The problem is that there is currently no way to feed dsoftmax_lse to the backward kernel, only dout.
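
To make the requirement concrete, here is a pure-PyTorch reference of what the extended backward would have to compute (naive attention of my own, not the flash-attn kernels; shapes and scaling are placeholder assumptions). The only change relative to the standard attention backward is that dsoftmax_lse_i gets added to every element of row i before the softmax backward, because dLSE_i/ds_ij = P_ij:

```python
import math
import torch

class RefAttnWithLSE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        s = torch.matmul(q, k.transpose(-2, -1)) * scale  # attention scores (b, h, sq, sk)
        lse = torch.logsumexp(s, dim=-1)                   # softmax_lse, (b, h, sq)
        p = torch.exp(s - lse.unsqueeze(-1))               # softmax(s), computed via the LSE
        out = torch.matmul(p, v)
        ctx.save_for_backward(q, k, v, out, p)
        ctx.scale = scale
        return out, lse

    @staticmethod
    def backward(ctx, dout, dlse):                         # dlse is the proposed dsoftmax_lse
        # (a real implementation would also handle dlse=None when the LSE is unused)
        q, k, v, out, p = ctx.saved_tensors
        dv = torch.matmul(p.transpose(-2, -1), dout)
        dp = torch.matmul(dout, v.transpose(-2, -1))
        # Standard softmax backward: ds_ij = P_ij * (dP_ij - D_i) with D_i = dout_i . out_i.
        # The LSE gradient only adds dlse_i to each row, since dLSE_i/ds_ij = P_ij.
        D = (dout * out).sum(dim=-1, keepdim=True)
        ds = p * (dp - D + dlse.unsqueeze(-1))
        dq = torch.matmul(ds, k) * ctx.scale
        dk = torch.matmul(ds.transpose(-2, -1), q) * ctx.scale
        return dq, dk, dv

# Check the explicit formulas against plain autograd on the same naive math.
torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 8, 16, dtype=torch.double, requires_grad=True)
           for _ in range(3))
out, lse = RefAttnWithLSE.apply(q, k, v)
(out.sum() + 0.1 * lse.sum()).backward()          # stand-in for dout + lambda * dsoftmax_lse
grads_custom = [t.grad.clone() for t in (q, k, v)]

for t in (q, k, v):
    t.grad = None
scale = 1.0 / math.sqrt(q.shape[-1])
s = torch.matmul(q, k.transpose(-2, -1)) * scale
(torch.matmul(torch.softmax(s, dim=-1), v).sum()
 + 0.1 * torch.logsumexp(s, dim=-1).sum()).backward()
for g, t in zip(grads_custom, (q, k, v)):
    assert torch.allclose(g, t.grad, atol=1e-10)
```

At least in this reference math, the whole change reduces to replacing D_i = rowsum(dout * out) with D_i - dsoftmax_lse_i; I would expect the kernel-side change to be similarly local, though I have only verified it against the naive reference above.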

A year or so ago (probably in the FlashAttention-1 era), as part of a project, I implemented mha_bwd() with support for dsoftmax_lse.

I think it could be valuable to port my changes to FlashAttention-2 and open a PR, so that other people can use this capability to backprop through the LSE.

Do you have guidelines for new contributors to this repo?

311dada commented 3 months ago

Looking forward to this PR!