ZetangForward opened 4 days ago
@ZetangForward Hi, in long-context scenarios we use flash decoding, which supports 64K-length inference on an 80GB GPU.
If you use Diff-Transformer/multihead_flashdiff_1, you can refer to https://aka.ms/flash-diff for flash decoding support.
If you use Diff-Transformer/multihead_flashdiff_2, you can refer to the official flash decoding support at https://github.com/Dao-AILab/flash-attention.
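For readers wondering what a KV-cache version of Diff Attention looks like, here is a minimal NumPy sketch of one differential-attention head during decoding. It assumes the paper's formulation, (softmax(Q1K1ᵀ/√d) − λ·softmax(Q2K2ᵀ/√d))V, and is not the flash-diff or FlashAttention kernel: it omits the per-head GroupNorm, the (1 − λ_init) scaling, and the reparameterization of λ. `DiffKVCache` and `diff_decode_step` are hypothetical names used only for illustration. The point is that caching K1, K2 and V lets each new token attend over past tokens without recomputing them, and the sketch checks the result against a full causal pass.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def diff_attention_full(q1, q2, k1, k2, v, lam):
    """Reference: causal differential attention for one head over a whole sequence.

    q1, q2, k1, k2: (n, d); v: (n, 2d). Builds two full (n, n) score maps.
    """
    n, d = q1.shape
    scale = 1.0 / np.sqrt(d)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # mask out future tokens
    s1 = np.where(future, -np.inf, q1 @ k1.T * scale)
    s2 = np.where(future, -np.inf, q2 @ k2.T * scale)
    return (softmax(s1) - lam * softmax(s2)) @ v


class DiffKVCache:
    """Hypothetical per-head cache holding both key halves (K1, K2) and V."""

    def __init__(self, d: int):
        self.d = d
        self.k1 = np.empty((0, d))
        self.k2 = np.empty((0, d))
        self.v = np.empty((0, 2 * d))

    def append(self, k1, k2, v):
        # Grow by one token; a real kernel would preallocate and write in place
        self.k1 = np.vstack([self.k1, k1[None, :]])
        self.k2 = np.vstack([self.k2, k2[None, :]])
        self.v = np.vstack([self.v, v[None, :]])


def diff_decode_step(cache, q1, q2, k1, k2, v, lam):
    """One decoding step: the new token's queries attend over all cached keys."""
    cache.append(k1, k2, v)
    scale = 1.0 / np.sqrt(cache.d)
    a1 = softmax(cache.k1 @ q1 * scale)  # (t,) first attention map for the new token
    a2 = softmax(cache.k2 @ q2 * scale)  # (t,) second attention map
    return (a1 - lam * a2) @ cache.v     # (2d,) differential output


# Usage: decoding token by token reproduces the full causal computation
rng = np.random.default_rng(0)
n, d, lam = 6, 4, 0.8
q1, q2, k1, k2 = (rng.standard_normal((n, d)) for _ in range(4))
v = rng.standard_normal((n, 2 * d))

cache = DiffKVCache(d)
step_out = np.stack(
    [diff_decode_step(cache, q1[t], q2[t], k1[t], k2[t], v[t], lam) for t in range(n)]
)
full_out = diff_attention_full(q1, q2, k1, k2, v, lam)
print(np.allclose(step_out, full_out))  # True
```

The cache grows linearly with sequence length and no (n, n) map is ever stored. Flash-decoding kernels additionally split the cached keys and values into chunks processed in parallel and combine the partial softmax results, which this sketch does not attempt.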
OK, thanks. BTW, I want to ask an additional question that is unrelated to the code.
I found that the paper has no explicit objective function constraining the "differential heads"; instead, the ability to "eliminate noise" is learned automatically through the designed gate mechanism. Is there an intuitive explanation for why the vanilla training objective (i.e., next-token prediction with cross-entropy loss) can make a pair of heads learn to eliminate noise through their difference? Does this phenomenon also occur in untrained models (e.g., Llama3)?
@YTianZHU
@ZetangForward Although the paired heads are independent in the forward pass, they can perceive each other in the backward pass. The two heads are fused together by differential attention (one attention map is subtracted from the other), so the gradients of the weights (Wq, Wk) carry information from both heads. These gradients guide the two heads to learn how to project the input relative to each other.
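This coupling can be checked numerically. The NumPy sketch below works under stated assumptions: a single unmasked head pair, V held fixed, and a squared-error loss standing in for next-token cross-entropy. `loss_and_grad_wq1` is a hypothetical helper, not code from the repo. Head 1's attention map is computed only from head 1's weights, yet the gradient of head 1's query weights changes when only head 2's weights change, because the loss gradient flows back through the fused output (A1 − λA2)V. With λ = 0 the coupling disappears.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Row-wise, numerically stable softmax
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def loss_and_grad_wq1(x, wq1, wk1, wq2, wk2, v, target, lam):
    """Loss and dL/dWq1 for one differential-attention head pair (no causal mask).

    L = 0.5 * ||(A1 - lam * A2) V - T||^2, a squared-error stand-in for the
    next-token cross-entropy loss (an assumption to keep the sketch short).
    """
    scale = 1.0 / np.sqrt(wq1.shape[1])
    q1, k1 = x @ wq1, x @ wk1
    a1 = softmax(q1 @ k1.T * scale)                # head-1 map: only head-1 weights
    a2 = softmax((x @ wq2) @ (x @ wk2).T * scale)  # head-2 map: only head-2 weights
    out = (a1 - lam * a2) @ v                      # the two heads are fused here
    g_out = out - target                           # dL/d(out): depends on BOTH maps
    g_a1 = g_out @ v.T                             # dL/dA1
    # Backpropagate through the row-wise softmax
    g_s1 = a1 * (g_a1 - (g_a1 * a1).sum(axis=-1, keepdims=True))
    g_q1 = g_s1 @ k1 * scale                       # dL/dQ1
    loss = 0.5 * np.sum(g_out ** 2)
    return loss, x.T @ g_q1                        # dL/dWq1


# Usage: perturb ONLY head 2's query weights and compare head 1's gradient
rng = np.random.default_rng(0)
n, dm, d, lam = 5, 8, 4, 0.5
x = rng.standard_normal((n, dm))
wq1, wk1, wq2, wk2 = (rng.standard_normal((dm, d)) * 0.5 for _ in range(4))
v = rng.standard_normal((n, 2 * d))
target = rng.standard_normal((n, 2 * d))

_, g_before = loss_and_grad_wq1(x, wq1, wk1, wq2, wk2, v, target, lam)
_, g_after = loss_and_grad_wq1(x, wq1, wk1, wq2 + 0.3, wk2, v, target, lam)
print(np.abs(g_before - g_after).max())  # nonzero: head 1's gradient sees head 2

# With lam = 0 the heads are no longer fused, so head 2 cannot affect head 1
_, g0_before = loss_and_grad_wq1(x, wq1, wk1, wq2, wk2, v, target, 0.0)
_, g0_after = loss_and_grad_wq1(x, wq1, wk1, wq2 + 0.3, wk2, v, target, 0.0)
print(np.allclose(g0_before, g0_after))  # True
```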
A model without any training (with randomly initialized weights) can't have this ability.
I noticed that the authors only provide the vanilla code in the Diff Attention repo. However, the paper also reports performance in long-context scenarios, and the vanilla implementation of Diff Attention cannot support a 64K context length on an 80GB GPU. How did the authors achieve long-context inference? Is there a KV-cache version of Diff Attention?
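A back-of-the-envelope estimate suggests why a naive implementation struggles at this length, assuming it materializes the full score matrices (as a plain Q·Kᵀ matmul followed by softmax does). The helper name `score_map_bytes` and the head count of 16 are hypothetical, and the estimate ignores weights, the KV cache, and extra temporary copies.

```python
GIB = 1024 ** 3


def score_map_bytes(seq_len: int, n_maps: int = 2, bytes_per_elem: int = 2) -> int:
    """Bytes to hold full (seq_len x seq_len) attention maps for one head and one sequence.

    n_maps=2 because differential attention computes two softmax maps per head;
    bytes_per_elem=2 assumes fp16/bf16 scores.
    """
    return n_maps * seq_len * seq_len * bytes_per_elem


per_head = score_map_bytes(65536)
print(per_head / GIB)  # 16.0

# Hypothetical layer with 16 heads whose maps are materialized together
num_heads = 16
print(num_heads * per_head / GIB)  # 256.0
```

Because the cost is quadratic in sequence length, kernels that never store the full maps (FlashAttention-style tiling for prefill, flash decoding for generation) are what make 64K contexts fit.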