dvlab-research / LongLoRA

Code and documents of LongLoRA and LongAlpaca (ICLR 2024 Oral)
http://arxiv.org/abs/2309.12307
Apache License 2.0

Help to confirm understanding of forward_flashattn #126

Closed: weicheng113 closed this issue 12 months ago

weicheng113 commented 1 year ago

Dear Authors and @yukang2017 ,

Thanks for the amazing work. I am trying to understand the following: https://github.com/dvlab-research/LongLoRA/blob/2a33f37543038877c70e9a625a61dc72a71621d0/llama_attn_replace_sft.py#L24

Could you please help me confirm whether the following understanding is correct?

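output_unpad = flash_attn_varlen_qkvpacked_func(
    # x_unpad has shape [bsz*2*seq_len, 3, num_heads//2, head_dim]; the 2 in
    # bsz*2*seq_len is the two half-heads for the non-shifted and shifted groups.
    qkv=x_unpad,
    # Example values for bsz=1, seq_len=8192, group_size=2048:
    cu_seqlens=[0, 2048, 4096, 6144, 8192, 9216, 11264, 13312, 15360, 16384],
    max_seqlen=group_size,  # 2048 in this example
    # ... (remaining arguments omitted)
)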

from https://github.com/dvlab-research/LongLoRA/blob/2a33f37543038877c70e9a625a61dc72a71621d0/llama_attn_replace_sft.py#L123
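To double-check the cu_seqlens values above, here is a minimal Python sketch that rebuilds them by hand. It assumes no padding (so every packed row keeps its full seq_len), the row order as read from the linked code (for each batch element, the non-shifted half-heads row followed by the shifted half-heads row), and a seq_len divisible by group_size. The helper name shifted_cu_seqlens is hypothetical and not part of LongLoRA or flash-attn.

```python
from typing import List


def shifted_cu_seqlens(bsz: int, seq_len: int, group_size: int) -> List[int]:
    """Hypothetical helper: cumulative group boundaries for the packed x_unpad rows.

    Assumes no padding and the row order
    [b0 non-shifted, b0 shifted, b1 non-shifted, b1 shifted, ...].
    """
    if seq_len % group_size != 0:
        raise ValueError("seq_len must be divisible by group_size")
    cu = [0]
    for row in range(2 * bsz):
        start = row * seq_len
        # Non-shifted rows split at G, 2G, ...; shifted rows split at G/2, 3G/2, ...
        first_split = group_size if row % 2 == 0 else group_size // 2
        cu.extend(start + offset for offset in range(first_split, seq_len, group_size))
        cu.append(start + seq_len)  # end of this row
    return cu


cu = shifted_cu_seqlens(bsz=1, seq_len=8192, group_size=2048)
print(cu)  # [0, 2048, 4096, 6144, 8192, 9216, 11264, 13312, 15360, 16384]
print([b - a for a, b in zip(cu, cu[1:])])  # [2048, 2048, 2048, 2048, 1024, 2048, 2048, 2048, 1024]
```

In this sketch the shift lives entirely in the boundaries: shifted rows split at group_size // 2 + k * group_size, so their first and last groups hold only group_size // 2 tokens, matching the [8192, 9216] and [15360, 16384] segments in the example.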

The result of the function call, output_unpad, holds the attention-weighted values. I tried to read the source code of flash_attn_varlen_qkvpacked_func to understand it, but got stuck at the native call flash_attn_cuda.varlen_fwd().
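Since the CUDA kernel is hard to read, a plain NumPy reference can help check the semantics. The sketch below is a naive re-implementation, under the assumption of causal=True and no dropout, of what the call is expected to compute: ordinary causal softmax attention run independently inside each segment [cu_seqlens[i], cu_seqlens[i+1]), with the default scale 1/sqrt(head_dim). It materializes each segment's full score matrix, which is exactly what FlashAttention avoids, so it is only suitable for small inputs. The function name varlen_causal_attention_reference is hypothetical.

```python
import numpy as np


def varlen_causal_attention_reference(qkv, cu_seqlens, softmax_scale=None):
    """Naive causal, dropout-free attention over packed variable-length segments.

    qkv: array of shape [total_tokens, 3, num_heads, head_dim]
    cu_seqlens: segment boundaries; tokens never attend across a boundary.
    Returns an array of shape [total_tokens, num_heads, head_dim].
    """
    total, three, num_heads, head_dim = qkv.shape
    assert three == 3, "expected packed q, k, v along axis 1"
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros((total, num_heads, head_dim), dtype=qkv.dtype)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        if end == start:
            continue  # skip empty segments
        q, k, v = (qkv[start:end, i] for i in range(3))  # each [L, heads, dim]
        scores = np.einsum("qhd,khd->hqk", q, k) * softmax_scale  # [heads, L, L]
        length = end - start
        future = np.triu(np.ones((length, length), dtype=bool), k=1)  # key after query
        scores = np.where(future, -np.inf, scores)
        scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[start:end] = np.einsum("hqk,khd->qhd", weights, v)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    cu = [0, 4, 8, 10, 14, 16]  # toy layout: bsz=1, seq_len=8, group_size=4
    qkv = rng.standard_normal((16, 3, 2, 8)).astype(np.float32)
    out = varlen_causal_attention_reference(qkv, cu)
    print(out.shape)  # (16, 2, 8)
```

With this reference, perturbing one token only changes the outputs at that position and at later positions inside the same segment, which is the isolation between local groups (and between the two half-head rows) that the question relies on.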

Could you please help confirm my understanding? My guess is that flash_attn_varlen_qkvpacked_func computes attention-weighted values for variable-length sequences (the local groups). The boundaries of the local groups are given by cu_seqlens, so in this example attention is computed separately within [0, 2048], [2048, 4096], ..., [8192, 9216], ..., [15360, 16384]. Some local groups (for example [8192, 9216] and [15360, 16384]) are shorter than max_seqlen (group_size), but flash_attn_varlen_qkvpacked_func handles these variable lengths.

Attention along the sequence is computed with a causal mask, and it is computed separately for the first half of the heads (non-shifted groups) and the second half of the heads (shifted groups). Therefore, there is no information leak from the shifted groups into the non-shifted groups. (There is a potential information leak in forward_noflashattn, as mentioned in #90. I think it could be prevented by rolling back before the attn_weights calculation, but this is not important because flash attention is always used for real training.)

Thanks to the shifted groups, a later, distant token in one local group can absorb information from an earlier token in another local group, and this reach is expected to grow over the layers to cover longer distances. An earlier token can never get information from a later token because of the causal mask. The attention-weighted values are computed separately for the non-shifted and shifted local groups, but the information is combined at the FeedForward layer.
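The point about information spreading over layers can be illustrated with a toy reachability sketch in NumPy. This is a simplified model (an assumption, not LongLoRA code): a position counts as visible if any chain of causal attention edges connects it, and the two half-head patterns are merged at every layer because the per-token layers after attention (the output projection and the FeedForward layer) combine them. Shifted groups follow the boundaries in the cu_seqlens example, offset by group_size // 2 with no wrap-around. The names s2_layer_mask and reachability are hypothetical.

```python
import numpy as np


def s2_layer_mask(seq_len: int, group_size: int) -> np.ndarray:
    """mask[i, j] is True if position i can read position j in one attention layer,
    through either the non-shifted or the shifted half of the heads (both causal)."""
    pos = np.arange(seq_len)
    causal = pos[None, :] <= pos[:, None]  # key j not after query i
    # Non-shifted groups: [0, G), [G, 2G), ...
    plain = pos // group_size
    # Shifted groups, as in the cu_seqlens example: [0, G/2), [G/2, 3G/2), ...
    shifted = (pos + group_size // 2) // group_size
    same_plain = plain[:, None] == plain[None, :]
    same_shifted = shifted[:, None] == shifted[None, :]
    return causal & (same_plain | same_shifted)


def reachability(seq_len: int, group_size: int, num_layers: int) -> np.ndarray:
    """reach[i, j] is True if input position j can influence position i after num_layers."""
    mask = s2_layer_mask(seq_len, group_size).astype(np.int64)
    reach = np.eye(seq_len, dtype=np.int64)
    for _ in range(num_layers):
        # Compose: a path j -> k through earlier layers, then k -> i in this layer.
        reach = ((mask @ reach) > 0).astype(np.int64)
    return reach.astype(bool)


if __name__ == "__main__":
    seq_len, group_size = 16, 4
    for layers in range(1, 8):
        reach = reachability(seq_len, group_size, layers)
        # First True in the last row = earliest position the last token can see.
        print(layers, int(reach[-1].argmax()))
```

In this toy setting the earliest position visible to the last token moves back by group_size // 2 with each extra layer (12, 10, 8, ... down to 0 after seven layers), and no row ever has a True above the diagonal, consistent with the causal-mask argument. The model only tracks whether some path exists and ignores attention weights, so it says nothing about how strongly distant information survives.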

Thanks in advance for your time and help, Cheng

yukang2017 commented 12 months ago

Hi,

I think your understanding is correct. Thanks for your question.

Regards, Yukang Chen

weicheng113 commented 12 months ago

Thanks, Yukang, for your time and answer.