The FlexAttention kernel is compiled with:
import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch._dynamo.config.cache_size_limit = 1000

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)  # , mode="max-autotune-no-cudagraphs")
I want to keep the compiled kernel static and avoid recompilation. How can I do that?
This shouldn't require a recompile. Where are you seeing that it needed to recompile?
Code for reference:
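A minimal sketch (not the code referenced above) of one way to keep the kernel static, assuming a causal mask and fixed tensor shapes: compile flex_attention once with dynamic=False and cache the BlockMask per sequence-length pair with lru_cache, so repeated calls with the same shapes reuse the same compiled kernel. The causal mask_mod, the cached_block_mask helper, and the example shapes are illustrative, not taken from this issue.

import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile once; dynamic=False asks the compiler to specialize on static shapes.
flex_attention = torch.compile(flex_attention, dynamic=False)

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask: a query position may only attend to itself and earlier keys.
    return q_idx >= kv_idx

@lru_cache(maxsize=32)
def cached_block_mask(q_len, kv_len):
    # Reusing the same (q_len, kv_len) pair returns the cached BlockMask,
    # so the compiled kernel sees identical inputs and does not recompile.
    return create_block_mask(causal, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device="cuda")

def attn(q, k, v):
    block_mask = cached_block_mask(q.shape[-2], k.shape[-2])
    return flex_attention(q, k, v, block_mask=block_mask)

# Example with fixed shapes (batch=2, heads=8, seq=1024, head_dim=64);
# calling attn again with the same shapes should hit the same compiled kernel.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
out = attn(q, k, v)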