pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License
478 stars, 23 forks

How to avoid re-compute mask #34

Open NonvolatileMemory opened 2 months ago

NonvolatileMemory commented 2 months ago

Hi FlexAttention Team,

Thanks for your code.

I use FlexAttention to implement a fast, IO-aware streaming attention with this mask:

def sliding_window_causal_with_stream(b, h, q_idx, kv_idx):
    # Causal mask ensures no future positions are attended
    causal_mask = q_idx >= kv_idx
    # Sliding window mask restricts the view within a window size
    window_mask = q_idx - kv_idx <= 256 
    # Stream mask ensures that for q_idx >= 4, kv_idx <= 4 is always visible
    stream_mask = (q_idx >= 4) & (kv_idx <= 4)
    # Combine all masks: sliding window and causal mask, or stream mask
    return (causal_mask & window_mask) | stream_mask

If I compile it as follows:

block_mask = create_block_mask(sliding_window_causal_with_stream, B=None, H=None, Q_LEN=8192, KV_LEN=8192,  _compile=True)
flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

It gives correct results when qlen, klen, vlen = 8192 or 1024, but wrong results when qlen=1024.

Do I need to rebuild the block mask? Or can I reuse the first 8192 one for every input less than 8192?
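
To make the mask above concrete, here is a small pure-Python sketch that counts how many key positions a given query can see. It reuses the mask function from the post unchanged and evaluates it on plain integers; `visible_keys` is a hypothetical helper for illustration, not part of FlexAttention. Note that `q_idx - kv_idx <= 256` admits 257 keys, counting the query's own position.

```python
def sliding_window_causal_with_stream(b, h, q_idx, kv_idx):
    # Same predicate as in the post, evaluated here on Python ints
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= 256
    stream_mask = (q_idx >= 4) & (kv_idx <= 4)
    return (causal_mask & window_mask) | stream_mask

def visible_keys(mask_mod, q_idx, kv_len):
    """Hypothetical helper: count kv positions the query at q_idx may attend to."""
    return sum(bool(mask_mod(0, 0, q_idx, kv)) for kv in range(kv_len))

# Early query: only the causal part applies (keys 0..3)
print(visible_keys(sliding_window_causal_with_stream, 3, 8192))     # 4
# Late query: 257 keys in the window plus the 5 always-visible keys 0..4
print(visible_keys(sliding_window_causal_with_stream, 1000, 8192))  # 262
```

Checking a few rows like this against a dense reference is a cheap way to confirm the predicate itself before looking at the block mask.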

drisspg commented 2 months ago

Whenever Q_LEN or KV_LEN changes, you will need to recompute the mask. It is possible to create a new, smaller block mask from the larger one, but you would need to do this manually from the internal tensor components of the BlockMask.
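
If the mask has to be rebuilt whenever the lengths change, one possible pattern is to cache one block mask per (Q_LEN, KV_LEN) pair so that each shape is built only once. The sketch below is an assumption about how that could look, not something from this thread: `make_cached_builder`, `build`, and `get_block_mask` are hypothetical names, and a stand-in builder replaces the real `create_block_mask` call so the example runs without a GPU.

```python
from functools import lru_cache

def make_cached_builder(build_fn):
    """Wrap build_fn(q_len, kv_len) so each shape is built at most once."""
    @lru_cache(maxsize=None)
    def cached(q_len, kv_len):
        return build_fn(q_len, kv_len)
    return cached

# Stand-in builder (assumption). In real use this would wrap something like:
#   create_block_mask(mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
calls = []
def build(q_len, kv_len):
    calls.append((q_len, kv_len))
    return ("block_mask", q_len, kv_len)

get_block_mask = make_cached_builder(build)
get_block_mask(8192, 8192)
get_block_mask(1024, 1024)
get_block_mask(8192, 8192)  # served from the cache, build() is not called again
print(calls)  # [(8192, 8192), (1024, 1024)]
```

With this pattern the block mask would be passed to `flex_attention` as a call-time argument rather than baked in with `partial`, so the compiled function can be reused across shapes.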

NonvolatileMemory commented 2 months ago

Thanks for your reply! In my view, many masks are agnostic to qlen and kvlen, so it might be better to provide a regenerate_mask function that takes new_input_len and max_len_block_mask.
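
The observation can be checked directly for any predicate that depends only on the indices: the dense mask for a shorter length is the top-left corner of the dense mask for a longer one. Below is a small pure-Python sketch with scaled-down sizes; `window_causal` and `dense_mask` are hypothetical names for illustration. It only covers the element-wise mask and says nothing about the block-level metadata stored in a BlockMask.

```python
def window_causal(b, h, q_idx, kv_idx, window=3):
    # Index-only predicate: causal attention restricted to a small window
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= window)

def dense_mask(mask_mod, q_len, kv_len):
    """Hypothetical helper: materialize the mask as a list of rows of bools."""
    return [[bool(mask_mod(0, 0, q, kv)) for kv in range(kv_len)]
            for q in range(q_len)]

big = dense_mask(window_causal, 16, 16)
small = dense_mask(window_causal, 8, 8)
top_left = [row[:8] for row in big[:8]]
print(small == top_left)  # True
```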

drisspg commented 2 months ago

We recently added indexing to the block_mask. The context is decoding, where you can create a larger block_mask, slice into it, and then attach the correct mask_mod: https://github.com/pytorch/pytorch/commit/09a339fc0605f7dad807efe6419de71ab28aafb7#diff-fdd6d17efe145eae3f8090031505ec062fc47ede339275a73c5e9e52c702dc91

Chillee commented 2 months ago

@NonvolatileMemory You can also just create a larger block mask to start with, and then reuse that mask - we support passing in a blockmask that was defined for a larger sequence than you're currently calling it with.

NonvolatileMemory commented 2 months ago

@Chillee I think that will cause a bug. I cannot pass allclose when the input size is 4096 but the mask is defined for 8192.

joydddd commented 2 months ago

> It gives correct results when qlen, klen, vlen = 8192 or 1024, but wrong results when qlen=1024.

Could you clarify which config triggers the bug? Is it the block mask defined with KV_LEN=8192, Q_LEN=8192, and the query passed in has a length of 1024 while k/v has a length of 8192?

NonvolatileMemory commented 6 days ago

Hi, here is my source code:

from functools import partial

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def block_causal_mask(b, h, q_idx, kv_idx):
    # A query may attend only to keys in strictly earlier blocks of 4 positions
    q_block = q_idx // 4
    kv_block = kv_idx // 4
    return q_block > kv_block

# The block mask is built for length 4096; diff() below calls it with seq_len=2560
block_mask = create_block_mask(block_causal_mask, B=None, H=None, Q_LEN=4096, KV_LEN=4096, _compile=True)
flex_attn = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

def torch_mask(q_idx, kv_idx, block_size=4):
    return q_idx // block_size > kv_idx // block_size

def diff(bsz=4, seq_len=128 * 20, d_head=128, num_heads=8, block_size=4):
    # Reference: dense masked attention in plain PyTorch, compared against flex_attn

    Q = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    K = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    V = torch.randn(bsz, num_heads, seq_len, d_head).cuda()

    scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (Q.size(-1) ** 0.5)

    q_idx = torch.arange(seq_len).view(-1, 1)
    kv_idx = torch.arange(seq_len).view(1, -1)
    mask = torch_mask(q_idx, kv_idx, block_size)[None, None, :, :].cuda()

    scores = scores.masked_fill(~mask, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    torch_out = torch.matmul(attn_weights, V)
    flex_out = flex_attn(Q, K, V)
    # Rows 0-3 attend to no keys under this mask (NaN in the reference), so compare from row 16 onward
    return (flex_out[:, :, 16:] - torch_out[:, :, 16:]).max()
a = diff()
print(a)
# tensor(1.2792, device='cuda:0')
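
One detail of the repro is worth spelling out: with `q_idx // 4 > kv_idx // 4`, the first four query rows cannot attend to any key, so the reference softmax sees only `-inf` in those rows and returns NaN there. The pure-Python sketch below lists those rows at a scaled-down length; `block_causal` mirrors `torch_mask` from the post, and `fully_masked_rows` is a hypothetical helper for illustration.

```python
def block_causal(q_idx, kv_idx, block_size=4):
    # Same predicate as torch_mask in the post, on Python ints
    return q_idx // block_size > kv_idx // block_size

def fully_masked_rows(mask_fn, q_len, kv_len):
    """Hypothetical helper: query rows that cannot attend to any key."""
    return [q for q in range(q_len)
            if not any(mask_fn(q, kv) for kv in range(kv_len))]

print(fully_masked_rows(block_causal, 32, 32))  # [0, 1, 2, 3]
```

When comparing the two outputs, taking `.abs()` before `.max()` would report the largest deviation in either direction rather than only the largest positive one.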