
static cache implementation is not compatible with attn_implementation==flash_attention_2 #32040

Open faaany opened 4 months ago

faaany commented 4 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

pytest -rA tests/test_cache_utils.py::CacheIntegrationTest -k "test_static_cache_greedy_decoding_pad_left and flash_attention"

fails with

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
>           raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )
E           ValueError: `static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers

src/transformers/models/llama/modeling_llama.py:388: ValueError
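For reference, here is a minimal sketch of how the same incompatibility can be reproduced outside of pytest, assuming flash-attn is installed and a CUDA device is available; the checkpoint name is only illustrative and is not necessarily the one used by the test:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # flash-attn backend
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Asking generate() for the static cache while the model runs flash_attention_2
# is expected to raise the ValueError from modeling_llama.py shown above.
out = model.generate(**inputs, max_new_tokens=5, cache_implementation="static")
```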

The right-padding test case fails as well:

pytest -rA tests/test_cache_utils.py::CacheIntegrationTest -k "test_static_cache_greedy_decoding_pad_right and flash_attention"

Expected behavior

Either we should not test flash_attention in this case, or we should add a check that skips setting cache_implementation to static when flash_attention_2 is used (a sketch of the second option is below).
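A minimal sketch of that second option, assuming the guard lives in the test itself; `_maybe_use_static_cache` is a hypothetical helper, not an existing function in tests/test_cache_utils.py:

```python
import pytest


def _maybe_use_static_cache(model, attn_implementation):
    # Hypothetical helper: only force the static cache when the attention
    # backend supports it; otherwise skip this parametrization.
    if attn_implementation == "flash_attention_2":
        pytest.skip("`static` cache is not compatible with `attn_implementation==flash_attention_2`")
    model.generation_config.cache_implementation = "static"
```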

faaany commented 4 months ago

I suggested a possible fix in this draft PR: https://github.com/huggingface/transformers/pull/32039, but I am not sure whether it is the right approach, so I also filed this issue.

amyeroberts commented 4 months ago

cc @gante too

zucchini-nlp commented 3 months ago

This incompatibility also affects Gemma2 with flash-attn, since Gemma2 does not support the dynamic cache.
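A hedged sketch of how the same error can be expected to surface for Gemma2, assuming the model is loaded with flash_attention_2 and left with its default (static-style) cache; the checkpoint name is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Gemma2 cannot fall back to the dynamic cache, so generating with its default
# cache is expected to hit the same static-cache/flash-attn incompatibility.
out = model.generate(**inputs, max_new_tokens=5)
```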