microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI
MIT License

How to merge a KV cache into the Diff Attention? #1661

Open ZetangForward opened 4 days ago

ZetangForward commented 4 days ago

I notice that the authors only provide the vanilla code in the Diff Attention repo. However, the paper also reports performance in long-context scenarios. The vanilla implementation of Diff Attention cannot support a 64K context length on an 80GB GPU. I wonder how the authors achieve long-context inference. Is there a KV cache version of Diff Attention?

YTianZHU commented 5 hours ago

@ZetangForward hi, in long-context scenarios, we use flash decoding, and it can support 64K length inference on 80GB GPU.

If you use Diff-Transformer/multihead_flashdiff_1, you can refer to https://aka.ms/flash-diff for flash decoding support.

If you use Diff-Transformer/multihead_flashdiff_2, you can refer to official flash decoding at https://github.com/Dao-AILab/flash-attention
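
For readers wondering what a KV cache has to hold for differential attention, here is a minimal sketch in plain Python. It is not the code from the repository or from flash decoding. It assumes a single head pair, a fixed scalar `lam` (the paper learns λ), and it omits the per-head normalization and other details of the released implementation. `DiffKVCache` and `diff_attend` are hypothetical names. Each decoding step appends one row to the cached K1, K2, and V and attends with only the new query, computing (softmax(q1·K1ᵀ/√d) − λ·softmax(q2·K2ᵀ/√d))·V.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def diff_attend(q1, q2, k1s, k2s, vs, lam):
    """One query against all keys/values seen so far (differential attention)."""
    scale = 1.0 / math.sqrt(len(q1))
    a1 = softmax([dot(q1, k) * scale for k in k1s])
    a2 = softmax([dot(q2, k) * scale for k in k2s])
    w = [x - lam * y for x, y in zip(a1, a2)]  # difference of the two attention maps
    return [sum(w[t] * vs[t][j] for t in range(len(vs))) for j in range(len(vs[0]))]


class DiffKVCache:
    """Hypothetical cache: stores both key streams and the shared values."""

    def __init__(self):
        self.k1, self.k2, self.v = [], [], []

    def step(self, q1, q2, k1, k2, v, lam):
        # Append the new token's keys/values, then attend with the new query only.
        self.k1.append(k1)
        self.k2.append(k2)
        self.v.append(v)
        return diff_attend(q1, q2, self.k1, self.k2, self.v, lam)


random.seed(0)
T, d = 5, 4
lam = 0.8  # assumption: fixed value for illustration; learnable in the paper


def rand_vec():
    return [random.gauss(0.0, 1.0) for _ in range(d)]


tokens = [{n: rand_vec() for n in ("q1", "q2", "k1", "k2", "v")} for _ in range(T)]

# Incremental decoding with the cache.
cache = DiffKVCache()
cached_out = [cache.step(t["q1"], t["q2"], t["k1"], t["k2"], t["v"], lam) for t in tokens]

# Reference: recompute causal attention over the full prefix at every position.
full_out = [
    diff_attend(
        tok["q1"], tok["q2"],
        [x["k1"] for x in tokens[: i + 1]],
        [x["k2"] for x in tokens[: i + 1]],
        [x["v"] for x in tokens[: i + 1]],
        lam,
    )
    for i, tok in enumerate(tokens)
]
```

In this sketch the cache grows by one K1 row, one K2 row, and one V row per generated token, and the cached outputs match the full causal recomputation. Kernels such as flash decoding handle the attention computation over a long cache far more efficiently than this illustration.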

ZetangForward commented 5 hours ago

OK, thanks. BTW, I want to ask an additional question that is unrelated to the code.

I found that the paper has no explicit objective function that constrains the "differential heads"; instead, the ability to "eliminate noise" is learned automatically through the designed gate mechanism. Could you explain intuitively why the vanilla training objective (i.e., next-token prediction with CE loss) can eliminate differences between a set of heads (two heads)? Does this phenomenon occur in untrained models (e.g., Llama3)?

@YTianZHU

YTianZHU commented 4 hours ago

@ZetangForward Although the paired heads are independent in the forward pass, they can perceive each other in the backward pass. The two heads are fused after differential attention, so the gradients of the weights (Wq, Wk) carry information from both heads. These gradients guide the two heads to learn how to project the input in relation to each other.

A model without any training (with randomly initialized weights) can't have this ability.
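
The coupling through the backward pass can be illustrated with a toy example. This is a sketch under stated assumptions, not the model's training code. It uses scalar values, attention scores `s1` and `s2` standing in for the two heads' query-key products, a fixed `lam`, and a squared-error loss in place of cross-entropy. All names are hypothetical. The gradient of the loss with respect to head 1's scores contains the factor `(o - target)`, and `o` depends on head 2, so changing head 2 changes the gradient that head 1 receives.

```python
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def diff_output(s1, s2, v, lam):
    """Fused output of the head pair: sum_t (a1_t - lam * a2_t) * v_t."""
    a1, a2 = softmax(s1), softmax(s2)
    return sum((p - lam * q) * x for p, q, x in zip(a1, a2, v))


def loss(s1, s2, v, lam, target):
    # Assumption: squared error stands in for the real training loss.
    return (diff_output(s1, s2, v, lam) - target) ** 2


def grad_wrt_head1(s1, s2, v, lam, target):
    """Analytic dL/ds1 via the chain rule through the fused output."""
    a1 = softmax(s1)
    o = diff_output(s1, s2, v, lam)  # depends on BOTH heads
    mean_v = sum(p * x for p, x in zip(a1, v))
    # dL/do = 2 * (o - target); do/ds1_t = a1_t * (v_t - mean_v) (softmax Jacobian)
    return [2.0 * (o - target) * p * (x - mean_v) for p, x in zip(a1, v)]


v = [1.0, -2.0, 0.5]
s1 = [0.3, -0.1, 0.8]  # head 1 scores, held fixed
lam, target = 0.5, 0.2

# Same head 1, two different settings of head 2.
g_a = grad_wrt_head1(s1, [0.0, 0.0, 0.0], v, lam, target)
g_b = grad_wrt_head1(s1, [2.0, -1.0, 0.0], v, lam, target)
print("grad for head 1, head 2 uniform:", g_a)
print("grad for head 1, head 2 peaked: ", g_b)
```

Head 1's parameters are identical in both calls, yet the two gradients differ, because the loss only sees the fused output of the pair.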