lambda7xx opened 3 months ago
The command is
torchrun --nproc_per_node 4 test/test_ring_flash_attn_func.py
The error is
[rank0]: Traceback (most recent call last):
File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97, in <module>
[rank0]: ring_out, ring_lse, _ = fn(
[rank0]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 214, in ring_flash_attn_qkvpacked_func
[rank0]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 159, in forward
[rank0]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 33, in ring_flash_attn_forward
[rank0]: TypeError: _flash_attn_forward() missing 1 required positional argument: 'softcap'
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97, in <module>
[rank3]: ring_out, ring_lse, _ = fn(
[rank3]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 214, in ring_flash_attn_qkvpacked_func
[rank3]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank3]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank3]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 159, in forward
[rank3]: File "/home/xxxxs/anaconda3/envs/stable/lib/python3.10/site-packages/ring_flash_attn-0.1-py3.10.egg/ring_flash_attn/ring_flash_attn.py", line 33, in ring_flash_attn_forward
[rank3]: TypeError: _flash_attn_forward() missing 1 required positional argument: 'softcap'
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/xxxxs/ring-flash-attention/test/test_ring_flash_attn_func.py", line 97,
Same issue here.
@lambda7xx @aniki-ly sorry for the late reply. The reason for the error is that the latest flash attn changed the argument list and the shape of the return value. I've just fixed them in #45 and #47. Could you pull the latest code and give another try? Thank you.
Thank you so much!
I use 4 GPUs to run the code. My command is
My error is