NonvolatileMemory opened this issue 2 months ago
Whenever Q_LEN or KV_LEN changes, you will need to recompute the mask. It is possible to create a new, smaller block mask from the larger version, but you would need to do this manually from the internal tensor components of the BlockMask.
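Roughly, that manual route could look something like this for the simple case where only Q_LEN shrinks and KV_LEN stays the same (an untested sketch; shrink_q_len is just an illustrative name, and it assumes a PyTorch version that exposes BlockMask.from_kv_blocks and the kv_* attributes):

from torch.nn.attention.flex_attention import BlockMask

def shrink_q_len(block_mask: BlockMask, new_q_len: int) -> BlockMask:
    # Keep only the Q-block rows that cover the new, shorter query length.
    # NOTE: this only handles a shrinking Q_LEN; if KV_LEN also shrinks, the
    # kv_num_blocks / kv_indices entries would also need to be re-clamped.
    q_blocks = (new_q_len + block_mask.BLOCK_SIZE[0] - 1) // block_mask.BLOCK_SIZE[0]
    full_num = block_mask.full_kv_num_blocks
    full_idx = block_mask.full_kv_indices
    return BlockMask.from_kv_blocks(
        block_mask.kv_num_blocks[:, :, :q_blocks],
        block_mask.kv_indices[:, :, :q_blocks],
        full_num[:, :, :q_blocks] if full_num is not None else None,
        full_idx[:, :, :q_blocks] if full_idx is not None else None,
        BLOCK_SIZE=block_mask.BLOCK_SIZE,
        mask_mod=block_mask.mask_mod,
    )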
Thanks for your reply! In my view, many masks are agnostic to qlen and kvlen, so maybe it would be better to provide a regenerate_mask that takes the new_input_len and the max_len_block_mask.
We recently added indexing to the block_mask. The context is decoding, where you can create a larger block_mask and then slice into it and attach the correct mask_mod: https://github.com/pytorch/pytorch/commit/09a339fc0605f7dad807efe6419de71ab28aafb7#diff-fdd6d17efe145eae3f8090031505ec062fc47ede339275a73c5e9e52c702dc91
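For decoding, the sliced mask plus an offset mask_mod looks roughly like this (an untested sketch; the causal mask_mod, the lengths, and offset_mask_mod are only illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

MAX_LEN = 8192
big_mask = create_block_mask(causal, B=None, H=None, Q_LEN=MAX_LEN, KV_LEN=MAX_LEN)

# Decode step: the single query token sits at logical position `pos`.
pos = 4096
q_block = pos // big_mask.BLOCK_SIZE[0]

# Slice out the row of Q blocks that covers `pos` ...
step_mask = big_mask[:, :, q_block : q_block + 1]

# ... and attach a mask_mod that offsets q_idx by the logical position of the decoded token.
def offset_mask_mod(b, h, q_idx, kv_idx):
    return causal(b, h, q_idx + pos, kv_idx)

step_mask.mask_mod = offset_mask_mod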
@NonvolatileMemory You can also just create a larger block mask to start with and then reuse it - we support passing in a BlockMask that was defined for a larger sequence than the one you're currently calling it with.
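For example, something along these lines (lengths here are only illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Build the mask once for the maximum length ...
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192)

# ... and reuse it for a shorter sequence.
q = torch.randn(1, 8, 1024, 64, device="cuda")
k = torch.randn(1, 8, 1024, 64, device="cuda")
v = torch.randn(1, 8, 1024, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)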
@Chillee I think it causes a bug: I cannot pass allclose when the input size is 4096 but the mask is defined for 8192.
It gives the correct answer when qlen, klen, vlen = 8192 or 1024, but the wrong answer when qlen = 1024.
Could you clarify which config triggers the bug? Is it the block mask defined with KV_LEN=8192 and Q_LEN=8192, where the query passed in has a length of 1024 and k/v have a length of 8192?
Hi, here is my source code:
import torch
import torch.nn.functional as F
from functools import partial
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)

def mask_mod(b, h, q_idx, kv_idx):
    # Streaming-style mask: a query block attends only to strictly earlier kv blocks.
    q_block = q_idx // 4
    kv_block = kv_idx // 4
    return q_block > kv_block

# The block mask is built for 4096, while the inputs below use seq_len = 2560.
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=4096, KV_LEN=4096, _compile=True)
flex_attn = torch.compile(partial(flex_attention, block_mask=block_mask, enable_gqa=True))

def torch_mask(q_idx, kv_idx, block_size=4):
    return q_idx // block_size > kv_idx // block_size

def diff(bsz=4, seq_len=128 * 20, d_head=128, num_heads=8, block_size=4):
    # Reference attention in plain PyTorch.
    Q = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    K = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    V = torch.randn(bsz, num_heads, seq_len, d_head).cuda()
    scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (Q.size(-1) ** 0.5)
    q_idx = torch.arange(seq_len).view(-1, 1)
    kv_idx = torch.arange(seq_len).view(1, -1)
    mask = torch_mask(q_idx, kv_idx, block_size)[None, None, :, :].cuda()
    scores = scores.masked_fill(~mask, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    torch_out = torch.matmul(attn_weights, V)
    flex_out = flex_attn(Q, K, V)
    # Skip the first rows, whose scores are fully masked (the reference softmax produces NaNs there).
    return (flex_out[:, :, 16:] - torch_out[:, :, 16:]).max()

a = diff()
print(a)
# tensor(1.2792, device='cuda:0')
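For comparison, and continuing from the snippet above, here is a sketch of the same call with the block mask rebuilt for the actual seq_len (2560), in line with the earlier suggestion to recompute the mask when the lengths change:

seq_len = 128 * 20  # the length actually used in diff()
matched_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
matched_attn = torch.compile(partial(flex_attention, block_mask=matched_mask, enable_gqa=True))
# then: flex_out = matched_attn(Q, K, V)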
Hi FlexAttention Team,
Thanks for your code.
I use flex attention to implement a fast IO-aware streaming attention with the mask shown in my code above.
If I compile it that way, it gives the correct answer when qlen, klen, vlen = 8192 or 1024, but the wrong answer when qlen = 1024.
Do I need to rebuild the block mask, or can I reuse the 8192 one for every input shorter than 8192?