fteufel opened this issue 3 weeks ago
The problem is that we don't support wrapping `create_block_mask` in `torch.compile`; instead, the BlockMask should be created outside of the compiled region and then amortized across all the attention layers. A "hack" to get it to work, though, is:
```python
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _mask_mod_signature


def generate_sliding_window(window_size: int) -> _mask_mod_signature:
    def sliding_window(b, h, q_idx, kv_idx):
        del b, h  # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        window_size=64,
        bias: bool = False,
    ):
        super().__init__()
        self.window_size = window_size
        self.dim = dim
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.to_out = nn.Linear(inner_dim, dim, bias=bias)

    @torch._dynamo.disable()
    def create_block_mask(self, l_q, l_k):
        sliding_window_mask = generate_sliding_window(window_size=self.window_size)
        # The call below resolves to the module-level create_block_mask imported
        # from torch.nn.attention.flex_attention, not to this method.
        block_mask = create_block_mask(
            sliding_window_mask, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k, _compile=True
        )
        return block_mask

    @torch.compile()
    def forward(self, x, context=None):
        kv_input = x if context is None else context
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)

        # Reshape q, k, v to (batch, heads, seq_len, dim_head)
        b, n, _ = q.shape
        q = q.view(b, n, self.heads, -1).transpose(1, 2)
        k = k.view(b, n, self.heads, -1).transpose(1, 2)
        v = v.view(b, n, self.heads, -1).transpose(1, 2)

        ######################
        # Flex attention block:
        l_q = q.shape[2]
        l_k = k.shape[2]
        print(q.shape, k.shape, v.shape)
        block_mask = self.create_block_mask(l_q, l_k)
        out = flex_attention(q, k, v, block_mask=block_mask)
        # Flex attention block end.
        #######################

        # Reshape output back to (batch, seq_len, inner_dim)
        out = out.transpose(1, 2).contiguous()
        out = out.view(b, n, -1)
        out = self.to_out(out)
        return out


device = torch.device('cuda')
model = Attention(768, 6, 128, window_size=64)
model.to(device)
test_inp = torch.ones(1, 65536, 768).to(device)
out = model(test_inp)
```
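For comparison, the recommended (non-hack) pattern would look roughly like the sketch below: the BlockMask is created once, outside of any compiled region, and then reused (amortized) across every attention layer that shares the same mask. The concrete window size and sequence lengths here are only illustrative.

```python
# Sketch of the recommended pattern: build the BlockMask eagerly, outside of
# torch.compile, and reuse it for every layer/call with the same mask shape.
mask_mod = generate_sliding_window(window_size=64)
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=65536, KV_LEN=65536)

# Compile flex_attention itself; the precomputed mask is simply passed in.
compiled_flex_attention = torch.compile(flex_attention)

def attend(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    return compiled_flex_attention(q, k, v, block_mask=block_mask)
```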
Thanks for the quick reply!
I still get the error when using the `@torch._dynamo.disable()` approach.
My layers will have different block masks, so just defining one globally wouldn't work; I'd really like them to be organized with the modules somehow. Is this simply not supported?
Hmm, interesting - so the code above that I linked is erroring for you? I wonder why I can't repro.
Yeah, copied it into a script, and still:

```
loc(callsite(callsite("/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":493:60 at "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":528:12) at "/tmp/torchinductor_fegt/vf/cvfsti7kjet6sssuaaflt2c6tq2tlouso74cd46rgp5yl6aayvhg.py":35:67)): error: 'tt.broadcast' op requires the same encoding for all operands and results
....
  File "/home/felix/test2.py", line 61, in torch_dynamo_resume_in_forward_at_59
    block_mask = self.create_block_mask(l_q, l_k)
....
```
@fteufel I think you'll also need to make sure you have the latest pytorch nightly - we added a workaround for this triton issue in the last couple days: https://github.com/pytorch/pytorch/pull/133413
Also, can you create your block mask at initialization time? You'll want to move the block mask construction out of the critical path anyways, as block mask construction is somewhat expensive: https://pytorch.org/blog/flexattention/#q-how-can-we-compute-blockmask-quicker
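For the module above, creating the mask at init time could look something like this sketch (it assumes the sequence length is fixed and known up front; `SlidingWindowAttention` and its arguments are placeholders, not part of the original code):

```python
# Hypothetical sketch: build the BlockMask once at construction time so that
# forward() only has to run flex_attention. Assumes a fixed, known seq_len.
class SlidingWindowAttention(nn.Module):
    def __init__(self, window_size: int, seq_len: int):
        super().__init__()
        mask_mod = generate_sliding_window(window_size=window_size)
        # Built eagerly, outside of any compiled region, and reused on every call.
        self.block_mask = create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
        )

    def forward(self, q, k, v):
        # q, k, v: (batch, heads, seq_len, dim_head)
        return flex_attention(q, k, v, block_mask=self.block_mask)
```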
@Chillee thanks - upgrading to the latest build changed things. Now I get:

```
Traceback (most recent call last):
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1287, in load
    FxGraphCache._check_can_cache(gm)
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1261, in _check_can_cache
    raise BypassFxGraphCache("Can't cache HigherOrderOperators.")
torch._inductor.codecache.BypassFxGraphCache: Can't cache HigherOrderOperators.
...
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
```
That doesn't feel quite right, though; I'm running this on an A10 with 23 GB and would expect this example to be doable.
Regarding critical paths: in the given example, would adding an `lru_cache` to `create_block_mask` achieve that, or should we instead pass the query and key lengths to `__init__` and then construct and register the block mask there?
> That doesn't feel quite right, though; I'm running this on an A10 with 23 GB and would expect this example to be doable.
I think it's an issue with our config option being too big for the GPU.
> Regarding critical paths: in the given example, would adding an `lru_cache` to `create_block_mask` achieve that, or should we instead pass the query and key lengths to `__init__` and then construct and register the block mask there?
The second one would work, I think, or you could do some kind of manual cache. I believe we just ignore `lru_cache` when compiling.
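One shape such a manual cache could take, reusing the `generate_sliding_window` helper from the code above (just a sketch; the function name and cache key are illustrative):

```python
# Sketch of a manual cache keyed on (window_size, Q_LEN, KV_LEN). Keeping it
# behind torch._dynamo.disable() means the dict lookup and the mask
# construction both stay outside of the compiled region.
_block_mask_cache = {}

@torch._dynamo.disable()
def get_block_mask(window_size, l_q, l_k):
    key = (window_size, l_q, l_k)
    if key not in _block_mask_cache:
        mask_mod = generate_sliding_window(window_size=window_size)
        _block_mask_cache[key] = create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k
        )
    return _block_mask_cache[key]
```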
Is there anything on my end that I can try to make it run? The `BLOCK_SIZE` argument does not seem to change anything, and the same code also runs without compilation on the same GPU.
Manual caching - do you mean something like this in `forward`?
```python
if self.block_mask is None:
    l_q, l_k = q.shape[2], k.shape[2]
    self.block_mask = self.create_block_mask(l_q, l_k)
```
Avoiding having to feed in `l_q` and `l_k` as hyperparameters to `__init__` would be nice; it just doesn't feel very torch-like compared to how we've written transformers so far.
@fteufel Perhaps doing `torch.compile(mode="max-autotune-no-cudagraphs")` would allow you to work around it?
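For the example above, that would just mean swapping the decorator on `forward`, e.g.:

```python
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, x, context=None):
    ...
```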
Thanks for the suggestion - unfortunately it doesn't do it. I see the same out-of-resource exceptions printed, until eventually I hit:
File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1336, in do_autotuning
raise NoValidChoicesError
torch._inductor.exc.LoweringException: NoValidChoicesError:
target: flex_attention
Yeah, I think we need to add some more fallback options for autotuning.
If you are feeling up to helping and are down to hack up your site-packages:
`/home/drisspg/.conda/envs/ao/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py`
This is where the flex_attention file is installed in my local conda env. If you hack up these configs: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_attention.py#L545-L592
you can try setting the block sizes to a smaller power of 2, or using fewer stages.
It also might be the case that you are hitting the decoding kernel: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py#L304-L314
Again, you can try setting the choice returned here to smaller block sizes and fewer stages.
Cool - tried it.
I actually hit the A100 branch on my A10 GPU:
```python
elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
    print('**** A100')
```
```
**** A100
Traceback (most recent call last):
[...same error as before]
```
Made it work like this:
```python
# commented this out
# default_config = _a100_default_config.get((dtype, head_dim), default_config)
default_config = (32, 16, 4, 3)
```
So the default modest-hardware fallback would have been fine; we just didn't get there, apparently.
Hi,
Thank you for providing this collection! I'm trying to get local window attention to run. I managed to get a simple example running locally as shown in #15, but I'm now facing problems when I try to wrap everything into a module and use it in an actual transformer.
Specifically, I'm hitting a Triton error that is beyond my understanding:
Code that produces the error:
Any ideas what might be going wrong here?
Torch version: 2.5.0.dev20240815+cu118