zhuzilin / ring-flash-attention

Ring attention implementation with flash attention
MIT License

Running the code raises errors #44

Open lambda7xx opened 3 months ago

lambda7xx commented 3 months ago

I used 4 GPUs to run the code. My command is:

torchrun --nproc_per_node 4 test/test_ring_flash_attn_varlen_func.py 

The error is:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 126, in <module>
[rank1]:     lse_list = extract_lse(lse, cu_seqlens)
[rank1]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 57, in extract_lse
[rank1]:     value = lse[i, :, : end - start]
[rank1]: IndexError: too many indices for tensor of dimension 2
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 126, in <module>
[rank0]:     lse_list = extract_lse(lse, cu_seqlens)
[rank0]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 57, in extract_lse
[rank0]:     value = lse[i, :, : end - start]
[rank0]: IndexError: too many indices for tensor of dimension 2
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 126, in <module>
[rank2]:     lse_list = extract_lse(lse, cu_seqlens)
[rank2]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py", line 57, in extract_lse
[rank2]:     value = lse[i, :, : end - start]
[rank2]: IndexError: too many indices for tensor of dimension 2
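
For context on the dimension mismatch: the test helper indexes lse with three indices (batch, head, position), but newer flash-attn releases appear to return the varlen lse as a 2-D tensor with all sequences concatenated along the token axis, which matches the "tensor of dimension 2" in the traceback. Below is a minimal, illustrative sketch of a shape-tolerant extract_lse, assuming the 2-D layout is (nheads, total_tokens); this is a sketch under that assumption, not the repository's actual fix.

```python
import torch

def extract_lse(lse: torch.Tensor, cu_seqlens: torch.Tensor) -> list:
    """Split a packed lse tensor back into one tensor per sequence."""
    values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        if lse.dim() == 3:
            # old layout: (batch, nheads, max_seqlen), right-padded per sequence
            values.append(lse[i, :, : end - start])
        else:
            # assumed new layout: (nheads, total_tokens), sequences concatenated
            values.append(lse[:, start:end])
    return values
```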
lambda7xx commented 3 months ago

With the command

torchrun --nproc_per_node 4 test/test_ring_flash_attn_func.py

the error is:

  File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97, in <module>
[rank0]:     ring_out, ring_lse, _ = fn(
[rank0]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 214, in ring_flash_attn_qkvpacked_func
[rank0]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 159, in forward
[rank0]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 33, in ring_flash_attn_forward
[rank0]: TypeError: _flash_attn_forward() missing 1 required positional argument: 'softcap'
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97, in <module>
[rank3]:     ring_out, ring_lse, _ = fn(
[rank3]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 214, in ring_flash_attn_qkvpacked_func
[rank3]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 159, in forward
[rank3]:   File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 33, in ring_flash_attn_forward
[rank3]: TypeError: _flash_attn_forward() missing 1 required positional argument: 'softcap'
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97, 
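
This second failure is the other half of the API change: newer flash-attn versions added a softcap argument to _flash_attn_forward, so wrappers built against the old signature raise the TypeError above. A hedged compatibility sketch, assuming only that the installed function may or may not accept softcap (the parameter name comes from the traceback; the wrapper itself is illustrative, not the actual patch):

```python
import inspect
from flash_attn.flash_attn_interface import _flash_attn_forward

# Check once whether the installed flash-attn expects the extra argument.
_HAS_SOFTCAP = "softcap" in inspect.signature(_flash_attn_forward).parameters

def flash_attn_forward_compat(*args, **kwargs):
    """Forward to _flash_attn_forward across flash-attn versions."""
    if _HAS_SOFTCAP:
        # softcap=0.0 is assumed to disable soft-capping, preserving
        # the behavior of versions that predate the argument.
        kwargs.setdefault("softcap", 0.0)
    return _flash_attn_forward(*args, **kwargs)
```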
aniki-ly commented 2 months ago

Same issue.

zhuzilin commented 1 month ago

@lambda7xx @aniki-ly Sorry for the late reply. The error happens because the latest flash-attn changed the argument list and the shape of the return value. I've just fixed them in #45 and #47. Could you pull the latest code and give it another try? Thank you.
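
A possible stopgap if pulling the fix is not convenient: pin flash-attn to a release that predates the softcap argument, e.g. `pip install "flash-attn<2.6"`. The 2.6 bound is an assumption about when the API changed, not something confirmed in this thread.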

lambda7xx commented 1 month ago

> @lambda7xx @aniki-ly Sorry for the late reply. The error happens because the latest flash-attn changed the argument list and the shape of the return value. I've just fixed them in #45 and #47. Could you pull the latest code and give it another try? Thank you.

Thank you so much!