zhuzilin / ring-flash-attention

Ring attention implementation with flash attention
MIT License

Got error in ZigZagRingFlashAttnVarlenFunc #46

Closed: ThisisBillhe closed this issue 6 days ago

ThisisBillhe commented 2 months ago
  1. It seems the batch dimension disappears after the _upad_input function (this function is usually copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input). The block_lse obtained at L118 of zigzag_ring_flash_attn_varlen.py then only has 2 dimensions (num_head and seq_len), which causes an error in the flatten_varlen_lse function (L120 of zigzag_ring_flash_attn_varlen.py), where block_lse is required to have three dimensions (see the shape sketch after the traceback below).
  2. An illegal memory access error is reported in the 'else' branch at L135 of zigzag_ring_flash_attn_varlen.py. I cannot even print the half_cu_seqlens or cu_seqlens tensor before the flatten_varlen_lse function:
    
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 140, in zigzag_ring_flash_attn_varlen_forward
    print(cu_seqlens)
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 431, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 595, in _str_intern
    tensor_str = _tensor_str(self, indent)
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 347, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 133, in __init__
    value_str = f"{value}"
    File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 933, in __format__
    return self.item().__format__(format_spec)
    RuntimeError: CUDA error: an illegal memory access was encountered
    '''
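To make point 1 concrete, here is a minimal sketch of the shape mismatch, assuming the packed 2-dim lse layout described above; the shapes and the expects_3d stand-in are illustrative and are not the repo's actual flatten_varlen_lse code:

    import torch

    batch, nheads, max_seqlen = 2, 8, 16
    seqlens = torch.tensor([10, 16])       # per-sequence lengths after unpadding
    total_tokens = int(seqlens.sum())      # 26 packed tokens

    # Older flash-attn varlen layout: one padded row per sequence (3 dims).
    lse_padded = torch.randn(batch, nheads, max_seqlen)

    # Newer layout: all sequences packed along the token axis (2 dims),
    # matching the 2-dim block_lse reported above.
    lse_packed = torch.randn(nheads, total_tokens)

    def expects_3d(lse):
        # Stand-in for the 3-dim assumption inside flatten_varlen_lse.
        assert lse.dim() == 3, f"expected 3 dims, got {lse.dim()}"

    expects_3d(lse_padded)   # passes
    expects_3d(lse_packed)   # raises AssertionError: expected 3 dims, got 2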
zhuzilin commented 2 months ago

Fixed in #47. The reason for the bug is explained here: https://github.com/zhuzilin/ring-flash-attention/issues/44#issuecomment-2330991965

ThisisBillhe commented 1 month ago

Thanks for your reply! I have tried your latest commit, and sadly it did not run well in my case: the program gets stuck. I think the reason is that the attention mask for _flash_attn_varlen_forward is different across ranks. Do you know how to address this?

ThisisBillhe commented 1 month ago

Perhaps we should send cu_seqlens_k and max_seqlen_in_batch_k along with k and v to other ranks.
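For reference, a rough sketch of what that suggestion could look like with torch.distributed point-to-point ops. The helper name ring_send_recv_kv and the packed int32 metadata buffer are illustrative assumptions, not the repo's actual API, and the sketch assumes every rank holds the same number of sequences so the metadata tensor has a fixed size:

    import torch
    import torch.distributed as dist

    def ring_send_recv_kv(k, v, cu_seqlens_k, max_seqlen_in_batch_k, send_rank, recv_rank):
        # Pack the varlen metadata into one int32 tensor so it can ride along with k/v.
        meta = torch.cat([
            cu_seqlens_k.to(torch.int32),
            torch.tensor([max_seqlen_in_batch_k], dtype=torch.int32, device=k.device),
        ])
        recv_k, recv_v, recv_meta = torch.empty_like(k), torch.empty_like(v), torch.empty_like(meta)

        # Exchange k, v, and the metadata with the neighbouring ranks in one batched call.
        ops = [
            dist.P2POp(dist.isend, k, send_rank),
            dist.P2POp(dist.isend, v, send_rank),
            dist.P2POp(dist.isend, meta, send_rank),
            dist.P2POp(dist.irecv, recv_k, recv_rank),
            dist.P2POp(dist.irecv, recv_v, recv_rank),
            dist.P2POp(dist.irecv, recv_meta, recv_rank),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()

        recv_cu_seqlens_k = recv_meta[:-1]
        recv_max_seqlen_k = int(recv_meta[-1].item())
        return recv_k, recv_v, recv_cu_seqlens_k, recv_max_seqlen_k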

zhuzilin commented 1 month ago

Hmm... are you using the latest main branch of the repo? I've just given it another try; it should work with:

torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py

And as for the attention mask being different across ranks, that is by design.
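One way to see why it is expected: in the zigzag scheme every sequence is split into 2 * world_size chunks and each rank keeps a different pair of chunks, so the local token positions (and therefore the masking applied at each ring step) necessarily differ per rank. The sketch below only mirrors that partitioning idea with a hypothetical helper; it is not the repo's actual extraction code:

    import torch

    def zigzag_extract_local(packed, cu_seqlens, rank, world_size):
        # Rank r keeps chunk r and chunk (2 * world_size - 1 - r) of every sequence.
        pieces = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            chunks = packed[start:end].chunk(2 * world_size, dim=0)
            pieces.append(chunks[rank])
            pieces.append(chunks[2 * world_size - 1 - rank])
        return torch.cat(pieces, dim=0)

    # Example: two sequences of length 8 and 16, packed together, world_size 2.
    cu_seqlens = torch.tensor([0, 8, 24])
    packed = torch.arange(24).unsqueeze(-1)   # stand-in for packed (total_tokens, ...) q/k/v
    for rank in range(2):
        local = zigzag_extract_local(packed, cu_seqlens, rank, world_size=2)
        print(rank, local.squeeze(-1).tolist())
    # rank 0 -> [0, 1, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23]
    # rank 1 -> [2, 3, 4, 5, 12, 13, 14, 15, 16, 17, 18, 19]
    # Each rank sees different positions of every sequence, so the causal masking it
    # needs in each ring step is rank-specific rather than identical across ranks.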