pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

When using the same mask and the same query/key/value shapes, how can I reuse the compiled flex_attention kernel instead of recompiling it? #28

Closed (foreverpiano closed this issue 3 months ago)

foreverpiano commented 3 months ago
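# Fragment from inside an attention layer's forward pass. query, key, value,
# seq_len, d_k, layer_index and prefix_lm_causal_mask come from the surrounding
# model code; flex_attention is the torch.compile-d version from the next comment.
import math
import time
from functools import lru_cache

import torch
from torch.nn.attention.flex_attention import create_block_mask

print("flex:", layer_index)

def expand_to_128(tensor):
    # Zero-pad the head dimension (e.g. 96) up to 128.
    padding_size = 128 - tensor.size(-1)
    return torch.nn.functional.pad(tensor, (0, padding_size))

query_expanded = expand_to_128(query)
key_expanded = expand_to_128(key)
value_expanded = expand_to_128(value)

@lru_cache
def create_block_mask_cached(mask_mod, B, H, M, N, device="cuda"):
    # Cached on (mask_mod, B, H, M, N, device); the cache only hits if the
    # same mask_mod function object is passed on every call.
    block_mask = create_block_mask(mask_mod, B, H, M, N, device=device)
    return block_mask

def noop(score, b, h, q_idx, kv_idx):
    # Identity score_mod; defined but not used in this snippet.
    return score

print(query_expanded.shape, key_expanded.shape, value_expanded.shape)

torch.cuda.synchronize()
before_time = time.perf_counter()
block_mask = create_block_mask_cached(prefix_lm_causal_mask, 1, 1, seq_len, seq_len)

# Pass the scale explicitly: the default would be based on the padded
# head dimension (128) rather than the original d_k.
hidden_states = flex_attention(query_expanded, key_expanded, value_expanded, block_mask=block_mask, scale=1./math.sqrt(d_k))

del block_mask  # only drops the local name; lru_cache still holds the mask
torch.cuda.synchronize()
end_time = time.perf_counter()
print("flex_attn: ", end_time - before_time)

def shrink_to_96(tensor):
    # Drop the padded columns to recover the original head dimension.
    return tensor[..., :96]

hidden_states = shrink_to_96(hidden_states)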

Code for reference.
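The pad-to-128 / slice-back trick in the snippet above is exact as long as the scale is pinned to the original head dimension: zero columns add nothing to the q·k dot products, and the zero columns of V only produce zero output columns, which shrink_to_96 drops. A minimal pure-Python sketch that checks this, assuming tiny illustrative sizes (3 padded to 4) and naive `attention`/`pad` helpers that stand in for flex_attention and expand_to_128:

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, scale):
    """Naive single-head attention; q, k, v are lists of row vectors."""
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out


def pad(rows, width):
    # Zero-pad every row on the right, like expand_to_128.
    return [row + [0.0] * (width - len(row)) for row in rows]


d_k, padded = 3, 4
q = [[0.1, 0.2, 0.3], [0.5, -0.4, 0.2]]
k = [[0.3, -0.1, 0.7], [0.2, 0.9, -0.5]]
v = [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]

reference = attention(q, k, v, scale=1.0 / math.sqrt(d_k))

# Padded inputs, but the scale still uses the original d_k.
padded_out = attention(pad(q, padded), pad(k, padded), pad(v, padded),
                       scale=1.0 / math.sqrt(d_k))
restored = [row[:d_k] for row in padded_out]  # like shrink_to_96

# Letting the scale follow the padded width changes the result.
wrong_scale = [row[:d_k] for row in attention(
    pad(q, padded), pad(k, padded), pad(v, padded), scale=1.0 / math.sqrt(padded))]
```

This is why the snippet passes `scale=1./math.sqrt(d_k)` instead of relying on the default.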

foreverpiano commented 3 months ago

The flex_attention kernel is compiled with:

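import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from functools import lru_cache

# Allow many more recompiles before Dynamo gives up and falls back to eager.
torch._dynamo.config.cache_size_limit = 1000

# Compile the flex_attention function; dynamic=False specializes on exact
# input shapes, so only a new shape should trigger a recompile.
flex_attention = torch.compile(flex_attention, dynamic=False) #, mode="max-autotune-no-cudagraphs")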
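With `dynamic=False`, a call with the same shapes should reuse the cached compiled graph, and only a new shape should trigger a recompile. The sketch below checks that behaviour on CPU with a toy softmax-attention function standing in for flex_attention and a small custom backend that counts compilations; `toy_attention` and `counting_backend` are illustrative names, not part of attention-gym or PyTorch.

```python
import torch
import torch._dynamo

compile_count = 0


def counting_backend(gm, example_inputs):
    # Dynamo calls the backend once per (re)compilation; run the graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward


def toy_attention(q, k, v):
    # Stand-in for flex_attention: plain scaled softmax attention.
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ v


torch._dynamo.reset()
compiled = torch.compile(toy_attention, backend=counting_backend, dynamic=False)

q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))
compiled(q, k, v)
compiled(q, k, v)  # same shapes: reuses the first compiled graph
after_same_shape = compile_count

q2, k2, v2 = (torch.randn(1, 2, 32, 8) for _ in range(3))
compiled(q2, k2, v2)  # new sequence length: dynamic=False forces a recompile
after_new_shape = compile_count
```

On a real run, setting the environment variable `TORCH_LOGS="recompiles"` makes Dynamo log the failed guard each time it recompiles, which shows exactly which input changed.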
foreverpiano commented 3 months ago

I want to keep the kernel static and avoid recompiling. How can I do that?
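Separately from the compiled kernel, the block mask is only reused when `create_block_mask_cached` gets a cache hit, and `functools.lru_cache` keys on the mask_mod function object itself (plus the other arguments). If the mask function is recreated on every call, for example by a factory that closes over a prefix length, every call is a miss and the mask is rebuilt. A minimal stdlib sketch, assuming a hypothetical `make_prefix_mask` factory and a stub `build_mask_cached` in place of `create_block_mask`:

```python
from functools import lru_cache

build_calls = 0


@lru_cache
def build_mask_cached(mask_mod, seq_len):
    # Stub for create_block_mask: counts how often real work happens.
    global build_calls
    build_calls += 1
    return (mask_mod.__name__, seq_len)


def make_prefix_mask(prefix_length):
    # Hypothetical factory: returns a brand-new closure on every call.
    def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) | (kv_idx < prefix_length)
    return prefix_lm_causal_mask


# Rebuilding the mask function for every layer: every call is a cache miss.
for _ in range(4):
    build_mask_cached(make_prefix_mask(16), 128)
misses_when_rebuilt = build_calls

# Building the mask function once and reusing it: one miss, then hits.
build_calls = 0
build_mask_cached.cache_clear()
shared_mask = make_prefix_mask(16)
for _ in range(4):
    build_mask_cached(shared_mask, 128)
misses_when_shared = build_calls
```

Creating the mask function once (for example at model init) and passing the same object from every layer keeps the cache effective.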

Chillee commented 3 months ago

This shouldn't require a recompile. Where are you seeing that it needed to recompile?
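One thing that can look like a recompile in the `flex_attn:` timings from the first snippet: the first call of a compiled function includes its compilation (and, there, the first `create_block_mask` call), so that timed region is much slower than later ones even when nothing recompiles. A minimal stdlib sketch of a timer that reports the first call separately from the steady state; `time_call` and `fake_compiled` are hypothetical helpers, and `sync` stands in for `torch.cuda.synchronize` on GPU:

```python
import time
from statistics import median


def time_call(fn, *args, repeats=5, sync=None):
    """Return (first_call_seconds, median_steady_state_seconds)."""
    sync = sync or (lambda: None)

    def timed():
        sync()  # keep earlier queued device work out of the measurement
        start = time.perf_counter()
        fn(*args)
        sync()  # wait for this call's work to finish
        return time.perf_counter() - start

    first = timed()  # includes any one-off compilation cost
    steady = [timed() for _ in range(repeats)]
    return first, median(steady)


# Usage with a stand-in that is slow only on its first call,
# like a compiled function paying its compile cost once.
calls = []


def fake_compiled(x):
    calls.append(x)
    if len(calls) == 1:
        time.sleep(0.05)  # simulated one-off compile
    return x * 2


first, steady = time_call(fake_compiled, 3)
print(f"first call: {first:.4f}s, steady state: {steady:.6f}s")
```

If the steady-state time stays flat across layers while only the first call is slow, that points to one-off compilation rather than repeated recompiles.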