feifeibear / long-context-attention

USP: Unified (a.k.a. Hybrid, 2D) Sequence Parallel Attention for Long Context Transformers Model Training and Inference

flash_attn version dependency #64

Closed. Eigensystem closed this issue 2 months ago.

Eigensystem commented 3 months ago

Description

The following error occurs when using flash_attn==2.6.1:

[rank1]:   File ".../lib/python3.10/site-packages/yunchang/ring/ring_flash_attn.py", line 33, in ring_flash_attn_forward
[rank1]:     block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
[rank1]: TypeError: _flash_attn_forward() missing 1 required positional argument: 'softcap'

For flash_attn>=2.6.0.post1, the function signature of _flash_attn_forward is:

def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
)

For flash_attn<=2.5.9.post1, the function signature of _flash_attn_forward is:

def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
)
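
Whether the installed build expects softcap can be determined up front from the package version, using the boundary quoted above. A minimal sketch, assuming flash_attn.__version__ is available and using packaging for the comparison (neither check is part of yunchang's current code):

# Sketch: decide at import time whether the installed flash_attn expects the
# extra `softcap` positional argument (present in the 2.6 series, absent in
# 2.5.9.post1 and earlier, per the signatures quoted above).
from packaging.version import parse as parse_version

import flash_attn

# True for e.g. 2.6.0.post1 or 2.6.1, False for 2.5.9.post1 and earlier.
FLASH_ATTN_HAS_SOFTCAP = parse_version(flash_attn.__version__) >= parse_version("2.6.0")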

Solution

  1. In setup.py, pin the dependency:

    install_requires=[
        'flash-attn<=2.5.9.post1',
    ],
  2. Or modify the interface to pass the softcap parameter only when the installed flash_attn expects it; a sketch of this option follows the list.
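
A minimal sketch of option 2, written as a hypothetical compatibility wrapper rather than the repository's actual fix (which was tracked in #70, per the closing comment below): inspect the installed _flash_attn_forward and pass softcap only when its signature includes it, so both the 2.5 and 2.6 series work.

# Sketch: version-tolerant call into flash_attn's private forward kernel.
# The wrapper name and defaults below are hypothetical; the argument order
# mirrors the two signatures quoted in the issue description.
import inspect

from flash_attn.flash_attn_interface import _flash_attn_forward

# True if the installed flash_attn (2.6 series) takes a `softcap` argument.
_HAS_SOFTCAP = "softcap" in inspect.signature(_flash_attn_forward).parameters


def flash_attn_forward_compat(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
):
    """Call _flash_attn_forward with or without `softcap`, matching the installed version."""
    if _HAS_SOFTCAP:
        return _flash_attn_forward(
            q, k, v, dropout_p, softmax_scale, causal,
            window_size, softcap, alibi_slopes, return_softmax,
        )
    # Older (<= 2.5.9.post1) signature has no softcap parameter.
    return _flash_attn_forward(
        q, k, v, dropout_p, softmax_scale, causal,
        window_size, alibi_slopes, return_softmax,
    )

The call site in ring_flash_attn.py could then unpack the returned tuple exactly as it does today.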

feifeibear commented 2 months ago

Closed with #70.