pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

`error: 'tt.broadcast' op requires the same encoding for all operands and results` for local window attention #26

Open fteufel opened 3 weeks ago

fteufel commented 3 weeks ago


Thank you for providing this collection! I'm trying to get local window attention to run. I managed to get a simple example running locally, as shown in #15, but I'm now running into problems when I wrap everything in a module and use it in an actual transformer.

Specifically, I'm hitting a Triton error that is beyond my understanding:

loc(callsite(callsite("/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/":493:60 at "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/":528:12) at "/tmp/torchinductor_fegt/vf/":35:67)): error: 'tt.broadcast' op requires the same encoding for all operands and results
The above exception was the direct cause of the following exception:
  File "/home/felix/", line 55, in torch_dynamo_resume_in_forward_at_54
    sliding_window_mask = generate_sliding_window(window_size=self.window_size)
  File "/home/felix/", line 57, in torch_dynamo_resume_in_forward_at_55
    block_mask = create_block_mask(
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/", line 800, in create_block_mask
    inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)

Code that produces the error:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _mask_mod_signature
from einops import rearrange

def generate_sliding_window(window_size: int) -> _mask_mod_signature:

    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        bias: bool = False,
        *,
        window_size: int,
    ):
        super().__init__()
        self.window_size = window_size
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
        self.to_out = nn.Linear(inner_dim, dim, bias=bias)

    def forward(self, x, context = None):

        kv_input = x if context is None else context
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h = self.heads), (q, k, v))

        # Flex attention block:
        q = rearrange(q, 'b n h d -> b h n d')
        k = rearrange(k, 'b n h d -> b h n d')
        v = rearrange(v, 'b n h d -> b h n d')

        l_q = q.shape[2]
        l_k = k.shape[2]
        print(q.shape, k.shape, v.shape)
        sliding_window_mask = generate_sliding_window(window_size=self.window_size)

        block_mask = create_block_mask(
            sliding_window_mask, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k, _compile=True
        )
        out = flex_attention(q, k, v, block_mask=block_mask)

        # Flex attention block end.

        out = rearrange(out, 'b h n d -> b n h d')
        # (b, n, h, d)
        out = rearrange(out, 'b n h d -> b n (h d)', h = self.heads)

        out = self.to_out(out)

        return out

device = torch.device('cuda')
model = Attention(768, 6, 128, window_size=64).to(device)

test_inp = torch.ones(1, 65536, 768, device=device)
out = model(test_inp)

Any ideas what might be going wrong here?

Torch version 2.5.0.dev20240815+cu118

drisspg commented 3 weeks ago

The problem is that we don't support wrapping "create_block_mask" in torch.compile. Instead, the BlockMask should be created outside of the compiled region and then amortized across all the attention layers. A "hack" to get it to work, though, is:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _mask_mod_signature

def generate_sliding_window(window_size: int) -> _mask_mod_signature:

    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        bias: bool = False,
        *,
        window_size: int,
    ):
        super().__init__()
        self.window_size = window_size
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
        self.to_out = nn.Linear(inner_dim, dim, bias=bias)

    # Keep block mask construction out of any enclosing torch.compile graph
    @torch._dynamo.disable()
    def create_block_mask(self, l_q, l_k):
        sliding_window_mask = generate_sliding_window(window_size=self.window_size)
        # Resolves to the imported create_block_mask, not this method
        block_mask = create_block_mask(
            sliding_window_mask, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k, _compile=True
        )
        return block_mask

    def forward(self, x, context=None):
        kv_input = x if context is None else context
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)

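        # Reshape q, k, v to (batch, heads, seq_len, dim_head)
        b, n, _ = q.shape
        q = q.view(b, n, self.heads, -1).transpose(1, 2)
        # k and v follow the context length, which can differ from n
        k = k.view(b, k.shape[1], self.heads, -1).transpose(1, 2)
        v = v.view(b, v.shape[1], self.heads, -1).transpose(1, 2)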

        # Flex attention block:
        l_q = q.shape[2]
        l_k = k.shape[2]
        print(q.shape, k.shape, v.shape)

        block_mask = self.create_block_mask(l_q, l_k)
        out = flex_attention(q, k, v, block_mask=block_mask)
        # Flex attention block end.

        # Reshape output
        out = out.transpose(1, 2).contiguous()
        out = out.view(b, n, -1)
        out = self.to_out(out)
        return out

device = torch.device('cuda')
model = Attention(768, 6, 128, window_size=64).to(device)

test_inp = torch.ones(1, 65536, 768, device=device)
out = model(test_inp)
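For intuition on why a BlockMask is worth building once and sharing across layers, here is a pure-Python sketch (not the library's implementation) that counts which tiles of the 65536 × 65536 attention matrix contain at least one allowed `(q_idx, kv_idx)` pair under the `abs(q_idx - kv_idx) <= window_size // 2` rule from `generate_sliding_window`. It assumes square 128 × 128 tiles, matching the default `BLOCK_SIZE` of `create_block_mask`; `tile_is_active` and `count_active_tiles` are hypothetical helpers.

```python
def tile_is_active(qb: int, kb: int, block: int, half_window: int) -> bool:
    """True if tile (qb, kb) holds some pair with |q_idx - kv_idx| <= half_window."""
    gap = abs(qb - kb)
    if gap == 0:
        return True  # diagonal tiles always contain q_idx == kv_idx
    # Smallest |q_idx - kv_idx| between two tiles that are `gap` blocks apart.
    min_dist = (gap - 1) * block + 1
    return min_dist <= half_window


def count_active_tiles(seq_len: int, block: int, window_size: int):
    """Return (active_tiles, total_tiles) for a seq_len x seq_len attention matrix."""
    n_blocks = -(-seq_len // block)  # ceil division
    half = window_size // 2
    active = sum(
        tile_is_active(qb, kb, block, half)
        for qb in range(n_blocks)
        for kb in range(n_blocks)
    )
    return active, n_blocks * n_blocks


active, total = count_active_tiles(65536, 128, 64)
print(f"{active} of {total} tiles active ({100 * active / total:.2f}%)")
# prints: 1534 of 262144 tiles active (0.59%)
```

Only the diagonal tiles and their immediate neighbours survive, which is the block sparsity `flex_attention` can skip over. Because the pattern depends only on the window size and the sequence lengths, every layer with the same window could reuse one mask.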
fteufel commented 3 weeks ago

Thanks for the quick reply!

I still get the error when using the @torch._dynamo.disable() approach.

My layers will have different block masks, so just defining one globally wouldn't work; I'd really like them to be organized with the modules somehow. Is this simply not supported?

drisspg commented 3 weeks ago

Hmm, interesting. So the code I linked above is erroring for you? I wonder why I can't repro it.

fteufel commented 3 weeks ago

Yeah, I copied it into a script and still get:

loc(callsite(callsite("/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/":493:60 at "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/":528:12) at "/tmp/torchinductor_fegt/vf/":35:67)): error: 'tt.broadcast' op requires the same encoding for all operands and results
  File "/home/felix/", line 61, in torch_dynamo_resume_in_forward_at_59
    block_mask = self.create_block_mask(l_q, l_k)

Chillee commented 3 weeks ago

@fteufel I think you'll also need to make sure you have the latest PyTorch nightly; we added a workaround for this Triton issue in the last couple of days:

Also, can you create your block mask at initialization time? You'll want to move block mask construction out of the critical path anyway, since it is somewhat expensive:

fteufel commented 3 weeks ago

@Chillee thanks, upgrading to the latest build changed things. Now I get:

Traceback (most recent call last):
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/", line 1287, in load
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/", line 1261, in _check_can_cache
    raise BypassFxGraphCache("Can't cache HigherOrderOperators.")
torch._inductor.codecache.BypassFxGraphCache: Can't cache HigherOrderOperators.

  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/triton/compiler/", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

That doesn't feel quite right, though. I'm running this on an A10 with 23 GB and would expect this example to be doable.

Regarding critical paths: in the given example, would adding an lru_cache to create_block_mask achieve that, or should we rather pass the query and key lengths to __init__ and then construct and register the block mask there?

Chillee commented 3 weeks ago

That doesn't feel quite right, though. I'm running this on an A10 with 23 GB and would expect this example to be doable.

I think it's an issue with the kernel config we pick being too big for this GPU.
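A quick check of the numbers in the `OutOfResources` error, as a sketch: the per-block (opt-in) shared-memory limits below are hand-copied from the CUDA programming guide's compute-capability table (163 KB for 8.0, 99 KB for 8.6 and 8.9, 227 KB for 9.0), not queried from the GPU, and the 147456-byte requirement is taken straight from the traceback. The 8.6 entry reproduces the reported "Hardware limit: 101376" exactly, which suggests the chosen config was sized for a GPU with more shared memory.

```python
KIB = 1024

# Max opt-in shared memory per thread block, by compute capability (bytes).
# Hand-copied values for illustration; check the CUDA guide for your GPU.
SMEM_PER_BLOCK = {
    (8, 0): 163 * KIB,  # A100
    (8, 6): 99 * KIB,   # A10, RTX 30xx
    (8, 9): 99 * KIB,   # L4, RTX 40xx
    (9, 0): 227 * KIB,  # H100
}

required = 147456  # "Required" from the OutOfResources error above

for (major, minor), limit in SMEM_PER_BLOCK.items():
    verdict = "fits" if required <= limit else "too big"
    print(f"sm_{major}{minor}: limit {limit // KIB} KiB, "
          f"kernel needs {required // KIB} KiB -> {verdict}")
```

The same 144 KiB kernel config would fit on an sm_80 part but not on an sm_86 one, so shrinking block sizes or `num_stages` (as the error message suggests) is what it takes on an A10.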

Regarding critical paths: in the given example, would adding an lru_cache to create_block_mask achieve that, or should we rather pass the query and key lengths to __init__ and then construct and register the block mask there?

The second one would work, I think, or some kind of manual cache. I think we just ignore lru_cache when compiling.

fteufel commented 3 weeks ago

Is there anything I can try on my end to make it run? The BLOCK_SIZE argument doesn't seem to change anything, and the same code runs without compilation on the same GPU.

Manual caching: do you mean something like this in forward?

if self.block_mask is None:
    l_q, l_k = q.shape[2], k.shape[2]
    self.block_mask = self.create_block_mask(l_q, l_k)

It would be nice to avoid passing l_q and l_k to __init__ as hyperparameters; it just doesn't feel very torch-like compared to how we have coded transformers so far.
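One way to get a manual cache without passing `l_q`/`l_k` to `__init__` is to key it on the sequence lengths, so a layer that later sees a different length builds a new mask instead of reusing a stale one (the `if self.block_mask is None` version keeps whatever it built first). A minimal pure-Python sketch: `MaskCache` is a hypothetical helper, and `build` stands in for a call such as `create_block_mask(mask_mod, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k)`, assumed to run outside the compiled region (for example via the `@torch._dynamo.disable()` approach mentioned earlier in the thread).

```python
from typing import Callable, Dict, Tuple


class MaskCache:
    """Hypothetical per-module cache holding one mask per (q_len, kv_len) pair."""

    def __init__(self, build: Callable[[int, int], object]):
        self._build = build  # e.g. a wrapper around create_block_mask
        self._masks: Dict[Tuple[int, int], object] = {}

    def get(self, q_len: int, kv_len: int):
        key = (q_len, kv_len)
        if key not in self._masks:
            # Only build on a miss; later calls with the same lengths reuse it.
            self._masks[key] = self._build(q_len, kv_len)
        return self._masks[key]

    def __len__(self) -> int:
        return len(self._masks)


# Usage with a stand-in builder that records how often it is called.
calls = []


def fake_build(q_len: int, kv_len: int) -> str:
    calls.append((q_len, kv_len))
    return f"mask[{q_len}x{kv_len}]"


cache = MaskCache(fake_build)
cache.get(1024, 1024)
cache.get(1024, 1024)  # cache hit, no rebuild
cache.get(2048, 2048)  # new length, new mask
print(len(cache), calls)
```

In the `Attention` module, each layer could own one such cache whose `build` closes over that layer's `window_size`, which keeps layers with different windows organized with their modules.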

Chillee commented 3 weeks ago

@fteufel Perhaps torch.compile(mode="max-autotune-no-cudagraphs") would let you work around it?

fteufel commented 3 weeks ago

Thanks for the suggestion; unfortunately it doesn't help. I see the same out-of-resource exceptions printed until I eventually hit:

  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/", line 1336, in do_autotuning
    raise NoValidChoicesError
torch._inductor.exc.LoweringException: NoValidChoicesError: 
  target: flex_attention

drisspg commented 3 weeks ago

Yeah, I think we need to add some more fallback_options for autotuning. If you're up for helping and don't mind hacking up your site-packages: /home/drisspg/.conda/envs/ao/lib/python3.12/site-packages/torch/_inductor/kernel/

This is where the flex_attention file is installed in my local conda env. If you hack these up:

you can try setting the block_size to a smaller power of 2, or using fewer stages.

It might also be that you are hitting the decoding kernel:

and again, you can try making the returned choice here smaller in block sizes and stages.

fteufel commented 3 weeks ago

Cool, tried it.

  1. I actually hit the A100 branch on my A10 GPU.

    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        print('**** A100')
    **** A100
    Traceback (most recent call last):
    [...same error as before]
  2. Made it work like this:

    # commented this out
    # default_config = _a100_default_config.get((dtype, head_dim), default_config)
    default_config = (32, 16, 4, 3)

So the default fallback for more modest hardware would have been fine; apparently we just never got there.
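Why an A10 ends up in the A100 branch: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and Python compares tuples element by element, so every GPU with major version 8 or higher satisfies `>= (8, 0)`, including sm_86 parts like the A10. A small sketch reproducing the quoted condition with hand-typed capabilities (the GPU list is illustrative only; `takes_a100_branch` is a hypothetical helper, not the inductor code itself):

```python
# Compute capabilities of a few GPUs, typed in by hand for illustration.
CAPABILITIES = {
    "T4": (7, 5),
    "A100": (8, 0),
    "A10": (8, 6),
    "L4": (8, 9),
    "H100": (9, 0),
}


def takes_a100_branch(capability, head_dim):
    # Same condition as the `elif` quoted above; tuples compare lexicographically.
    return head_dim <= 256 and capability >= (8, 0)


head_dim = 128  # dim_head of the example model
for name, cap in CAPABILITIES.items():
    print(f"{name:5s} {cap} -> A100 branch: {takes_a100_branch(cap, head_dim)}")
```

An A10 at `(8, 6)` therefore gets the same configs as an A100 at `(8, 0)`, even though its per-block shared-memory limit is lower, which is consistent with the out-of-resource errors reported in this thread.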