Dear Authors and @yukang2017,
Thanks for the amazing work. I am trying to understand the following: https://github.com/dvlab-research/LongLoRA/blob/2a33f37543038877c70e9a625a61dc72a71621d0/llama_attn_replace_sft.py#L24
Could you please help me confirm whether my understanding below is correct?
From https://github.com/dvlab-research/LongLoRA/blob/2a33f37543038877c70e9a625a61dc72a71621d0/llama_attn_replace_sft.py#L123:
The result of the function call, output_unpad, is the attention-weighted values. I tried to look into the source code of flash_attn_varlen_qkvpacked_func to understand it, but I was blocked by the native call to flash_attn_cuda.varlen_fwd().
Here is my understanding. My guess is that flash_attn_varlen_qkvpacked_func computes attention-weighted values for variable-length sequences (the local groups):

1. The boundaries of the local groups are given by cu_seqlens, so in this example attention is computed separately within [0, 2048], [2048, 4096], ..., [8192, 9216], ..., [15360, 16384]. Some local groups (for example [8192, 9216] and [15360, 16384]) are shorter than max_seqlen (group_size), but flash_attn_varlen_qkvpacked_func handles these variable lengths (a sketch of this layout follows below).
2. Within each group, attention along the sequence is computed with a causal attention mask, separately for the first half of the heads (non-shifted groups) and the second half of the heads (shifted groups). Therefore, there is no information leak from the shifted groups into the non-shifted groups. (There is a potential information leak in forward_noflashattn, as mentioned in #90. I think it could be prevented by rolling back before attn_weights is computed, but this is not important, since flash attention is always used for real training.)
3. Because of the shifted groups, a later, far-away token in one local group is able to absorb information from an earlier token in another local group, and this reach is expected to grow over the layers to cover longer distances. An earlier token can never get information from a later token because of the causal attention mask.
4. The attention-weighted values are computed separately for the non-shifted and shifted local groups, but the information is combined afterwards in the FeedForward layer.
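To make point 1 concrete, here is a minimal sketch (my own illustration, not the repository's code) of a cu_seqlens layout consistent with the boundaries above. It assumes the 16384-token batch packs two documents of lengths 9216 and 7168, each split into local groups of at most group_size = 2048 tokens, so the trailing remainder of each document becomes one of the shorter groups; the half-head shift and the final attention call are shown only as hedged comments, since they depend on tensor shapes and on the flash-attn package being installed.

```python
import torch

group_size = 2048            # local attention window (max_seqlen)
doc_lengths = [9216, 7168]   # assumed packed-document lengths for this example

# Cumulative group boundaries: full group_size chunks per document, plus a
# shorter remainder group when a document length is not a multiple of group_size.
boundaries = [0]
for doc_len in doc_lengths:
    start = boundaries[-1]
    for offset in range(group_size, doc_len + 1, group_size):
        boundaries.append(start + offset)
    if doc_len % group_size:
        boundaries.append(start + doc_len)

cu_seqlens = torch.tensor(boundaries, dtype=torch.int32)
print(cu_seqlens.tolist())
# [0, 2048, 4096, 6144, 8192, 9216, 11264, 13312, 15360, 16384]

# Hypothetical illustration of the half-head shift (shapes assumed): with hidden
# states x of shape (total_tokens, num_heads, head_dim), rolling the second half
# of the heads by -group_size // 2 along the token dimension gives those heads
# groups that straddle adjacent non-shifted groups:
#   x[:, num_heads // 2:] = x[:, num_heads // 2:].roll(-group_size // 2, dims=0)

# With qkv packed as (total_tokens, 3, num_heads, head_dim) in fp16/bf16 on a
# CUDA device, the variable-length causal call would then look like (requires
# the flash-attn package and a GPU, so left as a comment here):
#   from flash_attn import flash_attn_varlen_qkvpacked_func
#   out = flash_attn_varlen_qkvpacked_func(
#       qkv, cu_seqlens.cuda(), max_seqlen=group_size, dropout_p=0.0, causal=True
#   )
```

If this reading is right, the key point is that flash_attn_varlen_qkvpacked_func only needs the cumulative boundaries in cu_seqlens plus max_seqlen as an upper bound, so the shorter remainder groups are handled without any padding.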
Thanks in advance for your time and help,
Cheng