HKUNLP / ChunkLlama

[ICML'24] Data and code for our paper "Training-Free Long-Context Scaling of Large Language Models"

A Bug in flash_decoding_chunkllama.py #24

Open ZayIsAllYouNeed opened 1 month ago

ZayIsAllYouNeed commented 1 month ago

flash_decoding_chunkllama.py does not import `_prepare_4d_causal_attention_mask_for_sdpa` from `transformers.modeling_attn_mask_utils`, so the SDPA branch fails with a `NameError` when that helper is called.
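
A minimal sketch of the fix, assuming a transformers version (>= 4.36) that still ships these private helpers in `transformers.modeling_attn_mask_utils`:

    # at the top of flash_decoding_chunkllama.py
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )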

ChenxinAn-fdu commented 1 month ago

Hi, thank you for pointing this out, but we usually use FlashAttention in long-context settings. If you set `attn_implementation="flash_attention_2"`, the error should not happen.

The relevant attention-mask selection code:

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
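        # without the missing import noted above, this call raises a NameError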
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
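
For reference, a minimal loading sketch that enables FlashAttention-2 and therefore takes the first branch above; the checkpoint name and dtype are illustrative:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,               # FlashAttention-2 requires fp16/bf16
        attn_implementation="flash_attention_2",  # mask stays 2D; the SDPA helper is never called
        device_map="auto",
    )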