Fixed in #47. The reason for the bug is explained here: https://github.com/zhuzilin/ring-flash-attention/issues/44#issuecomment-2330991965
Thanks for your reply! I tried your latest commit, but sadly it did not run well in my case: the program gets stuck. I think the reason is that the attention mask for _flash_attn_varlen_forward is different across ranks. Do you know how to address this?
Perhaps we should send cu_seqlens_k and max_seqlen_in_batch_k along with k and v to other ranks.
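For illustration, here is a minimal sketch of what that could look like, assuming a standard torch.distributed process group and a simple ring topology. It also assumes cu_seqlens_k has the same length on every rank so the receive buffers can be pre-allocated; the helper name ring_send_recv_kv_with_meta is hypothetical and not part of this repo.

```python
import torch
import torch.distributed as dist

def ring_send_recv_kv_with_meta(k, v, cu_seqlens_k, max_seqlen_in_batch_k):
    """Hypothetical helper: ship local k/v plus varlen metadata to the next
    rank and receive the previous rank's in one batched P2P step."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size

    # Pack the varlen metadata into a single int64 tensor so it can be
    # communicated the same way as k and v. Assumes cu_seqlens_k has the
    # same length on every rank, so torch.empty_like gives a valid buffer.
    meta = torch.cat([
        cu_seqlens_k.to(torch.int64),
        torch.tensor([max_seqlen_in_batch_k], dtype=torch.int64, device=k.device),
    ])

    recv_k = torch.empty_like(k)
    recv_v = torch.empty_like(v)
    recv_meta = torch.empty_like(meta)

    ops = [
        dist.P2POp(dist.isend, k, send_rank),
        dist.P2POp(dist.irecv, recv_k, recv_rank),
        dist.P2POp(dist.isend, v, send_rank),
        dist.P2POp(dist.irecv, recv_v, recv_rank),
        dist.P2POp(dist.isend, meta, send_rank),
        dist.P2POp(dist.irecv, recv_meta, recv_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # Unpack the received metadata for the next varlen attention call.
    recv_cu_seqlens_k = recv_meta[:-1].to(torch.int32)
    recv_max_seqlen_in_batch_k = int(recv_meta[-1].item())
    return recv_k, recv_v, recv_cu_seqlens_k, recv_max_seqlen_in_batch_k
```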
Hmm... are you using the latest main branch of the repo? I've just given it another try, and it works with:
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
And as for the attention mask being different across ranks, that is by design.
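To illustrate why (this is a toy example with plain even splitting, not the repo's actual zigzag layout): when a packed varlen batch is sharded across ranks, each rank's slice crosses different sequence boundaries, so its local cu_seqlens, and therefore the implicit mask passed to _flash_attn_varlen_forward, is necessarily different.

```python
import torch

# Packed batch of two sequences with lengths [5, 3] -> 8 tokens total.
cu_seqlens = torch.tensor([0, 5, 8])   # global sequence boundaries
world_size = 2
tokens_per_rank = int(cu_seqlens[-1]) // world_size  # 4 tokens per rank here

for rank in range(world_size):
    start, end = rank * tokens_per_rank, (rank + 1) * tokens_per_rank
    # Clamp the global boundaries into this rank's token range, re-base to 0,
    # and drop boundaries that collapse onto each other.
    local = torch.unique(torch.clamp(cu_seqlens, start, end) - start)
    print(f"rank {rank}: local cu_seqlens = {local.tolist()}")
    # rank 0: [0, 4]    -> one partial sequence of 4 tokens
    # rank 1: [0, 1, 4] -> tail of seq 0 (1 token) + seq 1 (3 tokens)
```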