pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

FlexAttention results do not match FlashAttention results #50

Open tilmto opened 1 week ago

tilmto commented 1 week ago

Hi,

I noticed that for certain sequence lengths, given the same inputs, FlexAttention's output differs from FlashAttention's output.

For example, with an input sequence of length 137, FlexAttention with a causal mask will, by default, create a block mask of length 256. In this case, when inputting the same set of Q, K, and V to both FlexAttention and FlashAttention, the first 128 tokens of their output will be identical, but the remaining 9 tokens will show slight differences.

I’m wondering if this discrepancy is due to some specific handling of FlexAttention's attention mask, and how we can address it. Thanks!
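(For concreteness, a minimal comparison along these lines might look like the sketch below; the batch size, head count, head dimension, dtype, and seed are illustrative assumptions, and the FlashAttention side is approximated by forcing SDPA's flash backend.)

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch.manual_seed(0)
B, H, S, D = 1, 8, 137, 64  # sequence length not a multiple of 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# The block mask is built in 128-sized blocks, so Q_LEN = KV_LEN = 137 is padded to 256 internally.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
flex_out = flex_attention(q, k, v, block_mask=block_mask)

# FlashAttention reference via SDPA, forcing the flash backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    flash_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

diff = (flex_out - flash_out).abs()
print("max abs diff, tokens [0, 128):  ", diff[:, :, :128].max().item())
print("max abs diff, tokens [128, 137):", diff[:, :, 128:].max().item())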

drisspg commented 1 week ago

Do you have a repro? Since they are two different implementations with potentially different reduction orders, it is expected that there will be some slight numerical differences, within the bounds specified here: https://pytorch.org/docs/stable/notes/numerical_accuracy.html
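(As an aside, one standard way to check whether such a mismatch is just reduction-order noise is to compare both implementations against a higher-precision reference rather than against each other; a rough sketch reusing the tensors from the snippet above:)

import torch.nn.functional as F

# Sketch: float64 attention as a reference (this falls back to the math SDPA backend),
# then measure each implementation's error against it. If both errors are of similar
# magnitude, the flex/flash mismatch is ordinary floating-point noise.
ref = F.scaled_dot_product_attention(q.double(), k.double(), v.double(), is_causal=True).to(q.dtype)
print("flex  vs fp64 ref:", (flex_out - ref).abs().max().item())
print("flash vs fp64 ref:", (flash_out - ref).abs().max().item())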

Shapes that are not aligned to BlockM and BlockN do need special care to handle (padding and masking locally), which is the case you run into with sequence length = 137.
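(To illustrate that padding-and-masking idea at the user level, purely as a sketch since flex_attention already handles this internally: pad Q/K/V up to the next multiple of 128, hide the padded KV positions with and_masks, and slice the output back to the real length.)

import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, and_masks

def pad_to_block(x, block=128):
    # Zero-pad the sequence dimension of a (B, H, S, D) tensor up to a multiple of `block`.
    pad = (-x.shape[-2]) % block
    return F.pad(x, (0, 0, 0, pad))

S = q.shape[-2]                      # real length, e.g. 137 (q, k, v from the sketch above)
qp, kp, vp = pad_to_block(q), pad_to_block(k), pad_to_block(v)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def in_bounds(b, h, q_idx, kv_idx):
    return kv_idx < S                # mask out the padded KV positions

mask = create_block_mask(and_masks(causal, in_bounds), B=None, H=None,
                         Q_LEN=qp.shape[-2], KV_LEN=kp.shape[-2])
out = flex_attention(qp, kp, vp, block_mask=mask)[:, :, :S]  # drop the padded query rows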

tilmto commented 1 week ago

Hi,

Thanks for your prompt response! I have provided simple pseudocode below to illustrate my usage of FlexAttention (if a runnable code snippet is needed, please let me know).

Basically, to handle arbitrary input sequence lengths, I create the block mask on the fly. With this implementation, I found that when the input sequence has a length that is not a multiple of 128, such as 137, the block mask is automatically padded to 256. As a result, the output for the last 9 tokens differs from the result when the same sequence is fed to FlashAttention, while the first 128 tokens match for both attention implementations. In contrast, if the input sequence has a length of exactly 256, the two implementations produce identical outputs.

Could you let me know if we did anything wrong regarding the block_mask usage? Thanks!

import torch
from typing import Optional

from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention


class FlexAttention(LlamaAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        from torch.nn.attention.flex_attention import flex_attention, create_block_mask, and_masks, or_masks
        from functools import partial

        self.create_block_mask = create_block_mask

        # Standard causal mask_mod: each query attends to itself and to earlier positions.
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        self.causal_mask = causal_mask

        self.flex_attention = torch.compile(flex_attention)

    def forward(
            self,
            hidden_states: torch.Tensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            **kwargs,
    ):

        # Omitted: code that produces query_states, key_states, value_states

        # Build the block mask on the fly from the current KV length; for a length
        # that is not a multiple of 128 (e.g. 137) the mask is padded up to the
        # next block boundary (256).
        block_mask = self.create_block_mask(
            self.causal_mask, B=None, H=None,
            Q_LEN=key_states.shape[-2], KV_LEN=key_states.shape[-2],
        )

        attn_output = self.flex_attention(query_states, key_states, value_states, block_mask=block_mask)

        return attn_output

drisspg commented 1 week ago

Do you think you could provide a runnable example? We did have some early bugs when we initially added support for sequence lengths that are not multiples of 128, but they should be resolved on nightly.