ZayIsAllYouNeed opened this issue 1 month ago
Hi, thanks for pointing this out, but we usually use FlashAttention in long-context settings. If you set `attn_implementation="flash_attention_2"`, the error should not happen.
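For example, a minimal loading sketch (assuming transformers >= 4.36 with flash-attn installed; the checkpoint name is a placeholder for whatever long-context model you are using):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; substitute your own long-context model.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # FlashAttention-2 requires fp16 or bf16
    device_map="auto",
)
```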
The relevant attention-mask dispatch code:

```python
if self._use_flash_attention_2:
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
    # output_attentions=True can not be supported when using SDPA, and we fall back on
    # the manual implementation that requires a 4D causal mask in all cases.
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
else:
    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
```
However, flash_decoding_chunkllama.py does not have `from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa`, so the SDPA branch would fail with a NameError when it runs.

flash_decoding_chunkllama.py:
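A minimal fix sketch, assuming the file otherwise mirrors the transformers dispatch logic shown above (both helpers live in `transformers.modeling_attn_mask_utils` as of v4.36):

```python
# Add near the top of flash_decoding_chunkllama.py
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
```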