pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

`error: 'tt.broadcast' op requires the same encoding for all operands and results` for local window attention #26

Open fteufel opened 3 weeks ago

fteufel commented 3 weeks ago

Hi,

Thank you for providing this collection! I'm trying to get local window attention to run. I managed to get a simple example running locally, as shown in #15, but now I'm running into problems when I wrap everything into a module and use it in an actual transformer.

Specifically, I'm facing a triton error that is beyond my understanding:

loc(callsite(callsite("/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":493:60 at "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":528:12) at "/tmp/torchinductor_fegt/vf/cvfsti7kjet6sssuaaflt2c6tq2tlouso74cd46rgp5yl6aayvhg.py":35:67)): error: 'tt.broadcast' op requires the same encoding for all operands and results
....
The above exception was the direct cause of the following exception:
....
  File "/home/felix/test.py", line 55, in torch_dynamo_resume_in_forward_at_54
    sliding_window_mask = generate_sliding_window(window_size=self.window_size)
  File "/home/felix/test.py", line 57, in torch_dynamo_resume_in_forward_at_55
    block_mask = create_block_mask(
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 800, in create_block_mask
    inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
....

Code that produces the error:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _mask_mod_signature
from einops import rearrange

def generate_sliding_window(window_size: int) -> _mask_mod_signature:

    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        window_size=64,
        bias: bool = False,
    ):
        super().__init__()
        self.window_size = window_size
        self.dim=dim
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
        self.to_out = nn.Linear(inner_dim, dim, bias=bias)

    @torch.compile()
    def forward(self, x, context = None):

        kv_input = x if context is None else context
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h = self.heads), (q, k, v))

        ######################
        # Flex attention block:
        q = rearrange(q, 'b n h d -> b h n d')
        k = rearrange(k, 'b n h d -> b h n d')
        v = rearrange(v, 'b n h d -> b h n d')

        l_q = q.shape[2]
        l_k = k.shape[2]
        print(q.shape, k.shape, v.shape)
        sliding_window_mask = generate_sliding_window(window_size=self.window_size)

        block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k, _compile=True
        )
        out = flex_attention(q, k, v, block_mask=block_mask)

        # Flex attention block end.
        #######################

        out = rearrange(out, 'b h n d -> b n h d')
        # (b, n, h, d)
        out = rearrange(out, 'b n h d -> b n (h d)', h = self.heads)

        out = self.to_out(out)

        return out

device= torch.device('cuda')
model = Attention(768, 6, 128, window_size=64)
model.to(device)

test_inp = torch.ones(1,65536,768).to(device)
out = model(test_inp)

Any ideas what might be going wrong here?

Torch version 2.5.0.dev20240815+cu118

drisspg commented 3 weeks ago

The problem is that we don't support wrapping "create_block_mask" in torch.compile. Instead, the BlockMask should be created outside of the compiled region and then amortized across all the attention layers. A "hack" to get it to work, though, is:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _mask_mod_signature

def generate_sliding_window(window_size: int) -> _mask_mod_signature:

    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        window_size=64,
        bias: bool = False,
    ):
        super().__init__()
        self.window_size = window_size
        self.dim=dim
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
        self.to_out = nn.Linear(inner_dim, dim, bias=bias)

    @torch._dynamo.disable()
    def create_block_mask(self, l_q, l_k):
        sliding_window_mask = generate_sliding_window(window_size=self.window_size)
        block_mask = create_block_mask(
                sliding_window_mask, B=None, H=None, Q_LEN=l_q, KV_LEN=l_k, _compile=True
            )
        return block_mask

    @torch.compile()
    def forward(self, x, context=None):
        kv_input = x if context is None else context
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)

        # Reshape q, k, v
        b, n, _ = q.shape
        q = q.view(b, n, self.heads, -1).transpose(1, 2)
        k = k.view(b, n, self.heads, -1).transpose(1, 2)
        v = v.view(b, n, self.heads, -1).transpose(1, 2)

        ######################
        # Flex attention block:
        l_q = q.shape[2]
        l_k = k.shape[2]
        print(q.shape, k.shape, v.shape)

        block_mask = self.create_block_mask(l_q, l_k)
        out = flex_attention(q, k, v, block_mask=block_mask)
        # Flex attention block end.
        #######################

        # Reshape output
        out = out.transpose(1, 2).contiguous()
        out = out.view(b, n, -1)
        out = self.to_out(out)
        return out

device= torch.device('cuda')
model = Attention(768, 6, 128, window_size=64)
model.to(device)

test_inp = torch.ones(1,65536,768).to(device)
out = model(test_inp)

fteufel commented 3 weeks ago

Thanks for the quick reply!

I still get the error when using the @torch._dynamo.disable() approach.

My layers will have different block masks, so just defining one globally wouldn't work; I'd really like them to be organized with the modules somehow. Is this simply not supported?

drisspg commented 3 weeks ago

Hmm, interesting. So the code I linked above is erroring for you? I wonder why I can't repro.

fteufel commented 3 weeks ago

Yeah, I copied it into a script and still get:

loc(callsite(callsite("/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":493:60 at "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":528:12) at "/tmp/torchinductor_fegt/vf/cvfsti7kjet6sssuaaflt2c6tq2tlouso74cd46rgp5yl6aayvhg.py":35:67)): error: 'tt.broadcast' op requires the same encoding for all operands and results
....
  File "/home/felix/test2.py", line 61, in torch_dynamo_resume_in_forward_at_59
    block_mask = self.create_block_mask(l_q, l_k)
....

Chillee commented 3 weeks ago

@fteufel I think you'll also need to make sure you have the latest PyTorch nightly; we added a workaround for this Triton issue in the last couple of days: https://github.com/pytorch/pytorch/pull/133413

Also, can you create your block mask at initialization time? You'll want to move block mask construction out of the critical path anyway, as it is somewhat expensive: https://pytorch.org/blog/flexattention/#q-how-can-we-compute-blockmask-quicker
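Roughly, a block mask records, for each (query block, key block) tile, whether the mask keeps none, some, or all of the points in that tile, so flex_attention can skip empty tiles entirely and skip calling mask_mod on full ones. Working that out means evaluating mask_mod at every (q_idx, kv_idx) pair. Below is a brute-force, pure-Python sketch of that idea, using the same sliding-window predicate as this issue. It is not PyTorch's implementation; `classify_blocks` is a hypothetical helper.

```python
from collections import Counter


def sliding_window(window_size):
    # Same predicate as the mask_mod in this thread, minus the unused b/h args.
    def mask(q_idx, kv_idx):
        return abs(q_idx - kv_idx) <= window_size // 2
    return mask


def classify_blocks(mask, q_len, kv_len, block_size):
    """Hypothetical helper: label each (q_block, kv_block) tile as
    'full' (every point kept), 'partial', or 'empty' (every point masked)."""
    labels = {}
    for qb in range(0, q_len, block_size):
        for kb in range(0, kv_len, block_size):
            kept = total = 0
            # Brute force: evaluate the mask at every point inside the tile.
            for q in range(qb, min(qb + block_size, q_len)):
                for k in range(kb, min(kb + block_size, kv_len)):
                    kept += mask(q, k)
                    total += 1
            tile = (qb // block_size, kb // block_size)
            if kept == 0:
                labels[tile] = "empty"
            elif kept == total:
                labels[tile] = "full"
            else:
                labels[tile] = "partial"
    return labels


labels = classify_blocks(sliding_window(64), q_len=1024, kv_len=1024, block_size=128)
print(Counter(labels.values()))  # Counter({'empty': 42, 'partial': 22})
```

With 1024 tokens this already takes about a million mask calls; at the 65536 tokens used in this issue it would be 4096 times as many, which is why building the mask once, outside the per-step path, pays off.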

fteufel commented 3 weeks ago

@Chillee Thanks! Upgrading to the latest build changed things. Now I get:

Traceback (most recent call last):
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1287, in load
    FxGraphCache._check_can_cache(gm)
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1261, in _check_can_cache
    raise BypassFxGraphCache("Can't cache HigherOrderOperators.")
torch._inductor.codecache.BypassFxGraphCache: Can't cache HigherOrderOperators.

...
  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

That doesn't feel quite right though, I'm running this on an A10 with 23GB and would expect this example to be doable.

Regarding critical paths - in the given example, would adding an lru_cache to create_block_mask achieve that, or should we rather pass the query+key lengths to __init__ and then construct+register the block mask there?

Chillee commented 3 weeks ago

> That doesn't feel quite right though, I'm running this on an A10 with 23GB and would expect this example to be doable.

I think the issue is that our kernel config is too big for this GPU.
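The numbers in the traceback make this concrete. The kernel asked for 147456 bytes (144 KiB) of shared memory per thread block, while the A10 (compute capability 8.6) allows at most 101376 bytes (99 KiB). The 23 GB of device memory does not enter into it: shared memory is a small on-chip, per-SM resource. The sketch below assumes the per-block opt-in limits listed in NVIDIA's CUDA C++ Programming Guide (163 KiB for 8.0, 99 KiB for 8.6, 227 KiB for 9.0); `MAX_SHARED_PER_BLOCK` and `fits` are hypothetical names.

```python
# Maximum shared memory per thread block (opt-in), in bytes, by compute
# capability, as listed in the CUDA C++ Programming Guide.
MAX_SHARED_PER_BLOCK = {
    (8, 0): 163 * 1024,  # A100
    (8, 6): 99 * 1024,   # A10
    (9, 0): 227 * 1024,  # H100
}

REQUIRED = 147456  # "Required: 147456" from the OutOfResources error above


def fits(capability, required_bytes=REQUIRED):
    """Hypothetical check: would a kernel needing `required_bytes` launch?"""
    return required_bytes <= MAX_SHARED_PER_BLOCK[capability]


for cap in sorted(MAX_SHARED_PER_BLOCK):
    print(cap, MAX_SHARED_PER_BLOCK[cap], fits(cap))
```

So a config that launches fine on an A100 can legitimately need more shared memory than an sm_86 card provides.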

> Regarding critical paths - in the given example, would adding an lru_cache to create_block_mask achieve that, or should we rather pass the query+key lengths to __init__ and then construct+register the block mask there?

The second one would work, I think, as would some kind of manual cache. I think we just ignore lru_cache when compiling.

fteufel commented 3 weeks ago

Is there anything on my end that I can try to make it run? The BLOCK_SIZE argument does not seem to change anything, and the same code runs fine without compilation on the same GPU.

Manual caching - do you mean something like this in forward?

# Assumes self.block_mask = None was set in __init__
if self.block_mask is None:
    l_q, l_k = q.shape[2], k.shape[2]
    self.block_mask = self.create_block_mask(l_q, l_k)

It would be nice to avoid passing l_q and l_k to __init__ as hyperparameters; that just doesn't feel very torch-like compared to how we have coded transformers so far.
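One way to get the "manual cache" mentioned above without passing lengths to __init__ is to key the cache on everything the mask depends on, e.g. (window_size, q_len, kv_len). A new sequence length then builds a new mask instead of reusing the first one, which is what the `if self.block_mask is None` version does. Sharing one cache object across layers also lets layers with the same window reuse a mask while layers with different windows each get their own. A minimal sketch: `BlockMaskCache` is a hypothetical helper, and `build` stands in for whatever actually calls `create_block_mask`. The assumption is that `get` runs outside the compiled region, for example from a method decorated with `@torch._dynamo.disable()` as in the workaround earlier in the thread.

```python
from typing import Callable, Dict, Tuple


class BlockMaskCache:
    """Hypothetical helper: build each mask once, keyed by what it depends on."""

    def __init__(self, build: Callable[[int, int, int], object]):
        # build(window_size, q_len, kv_len) -> mask; assumed to wrap
        # create_block_mask and to run outside the compiled region.
        self._build = build
        self._masks: Dict[Tuple[int, int, int], object] = {}

    def get(self, window_size: int, q_len: int, kv_len: int) -> object:
        key = (window_size, q_len, kv_len)
        if key not in self._masks:
            self._masks[key] = self._build(window_size, q_len, kv_len)
        return self._masks[key]


# Usage with a stand-in builder that records how often it is called.
calls = []


def fake_build(window_size, q_len, kv_len):
    calls.append((window_size, q_len, kv_len))
    return f"mask(w={window_size}, {q_len}x{kv_len})"


cache = BlockMaskCache(fake_build)
# Three "layers" on a 1024-token input (two share window 64), then a longer input.
for window, length in [(64, 1024), (64, 1024), (128, 1024), (64, 2048)]:
    cache.get(window, length, length)
print(len(calls))  # 3 distinct masks built
```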

Chillee commented 3 weeks ago

@fteufel Perhaps doing torch.compile(mode="max-autotune-no-cudagraphs") would allow you to workaround it?

fteufel commented 3 weeks ago

Thanks for the suggestion; unfortunately it doesn't help. I see the same out-of-resource exceptions printed until I eventually hit:

  File "/home/felix/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1336, in do_autotuning
    raise NoValidChoicesError
torch._inductor.exc.LoweringException: NoValidChoicesError: 
  target: flex_attention

drisspg commented 3 weeks ago

Yeah, I think we need to add some more fallback options for autotuning. If you're feeling up to helping and are willing to hack up your site-packages, this is where the flex_attention file is installed in my local conda env: /home/drisspg/.conda/envs/ao/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py

If you edit these lines: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_attention.py#L545-L592

you can try setting the block sizes to a smaller power of 2 or using fewer stages.

It also might be the case that you are hitting the decoding kernel: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py#L304-L314

and again, you can try changing the returned choice there to use smaller block sizes and fewer stages.
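Following that suggestion, a list of progressively cheaper configs can be generated mechanically by halving the block sizes (so they stay powers of 2) and lowering num_stages. A minimal sketch: it assumes the config tuple is laid out as (BLOCK_M, BLOCK_N, num_warps, num_stages), which is our reading of the linked kernel file rather than something stated in this thread. `fallback_configs` and its minimum block sizes are hypothetical, not an Inductor API.

```python
from typing import List, Tuple

# Assumed layout: (BLOCK_M, BLOCK_N, num_warps, num_stages)
Config = Tuple[int, int, int, int]


def fallback_configs(start: Config, min_block_m: int = 16,
                     min_block_n: int = 16) -> List[Config]:
    """Hypothetical helper: list cheaper variants of `start`, beginning with it."""
    block_m, block_n, warps, stages = start
    out: List[Config] = []
    bm = block_m
    while bm >= min_block_m:
        bn = block_n
        while bn >= min_block_n:
            # Fewer stages means fewer tiles buffered in shared memory at once.
            for s in range(stages, 0, -1):
                out.append((bm, bn, warps, s))
            bn //= 2
        bm //= 2
    return out


candidates = fallback_configs((128, 64, 4, 3))
print(len(candidates), candidates[0], candidates[-1])
```

A manual loop (or an autotuner) could try these in order and keep the first one that compiles. The (32, 16, 4, 3) config that ends up working later in this thread is one of the generated candidates.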

fteufel commented 3 weeks ago

Cool - tried it.

  1. I actually hit the A100 branch on my A10 GPU. With a debug print added:

    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        print('**** A100')

     The output was:

    **** A100
    Traceback (most recent call last):
    [...same error as before]

  2. I made it work like this:

    # commented this out
    # default_config = _a100_default_config.get((dtype, head_dim), default_config)
    default_config = (32, 16, 4, 3)

So the default fallback config for modest hardware would have been fine; apparently we just never reached it.
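The A10 lands in the A100 branch because `torch.cuda.get_device_capability()` returns a tuple and Python compares tuples lexicographically. The A10 reports (8, 6), and (8, 6) >= (8, 0) is true, even though sm_86 offers far less shared memory per block (99 KiB) than sm_80 (163 KiB). A minimal sketch of the branch structure quoted above returns branch names instead of configs. It assumes the preceding `if` checks for (9, 0) (H100). It also includes a hypothetical stricter variant, `pick_branch_exact`, that matches capabilities exactly so other 8.x devices fall through to the modest-hardware config.

```python
def pick_branch(capability, head_dim):
    """Mirror of the quoted branch structure; returns the branch name."""
    if head_dim <= 256 and capability >= (9, 0):    # assumed H100 branch
        return "H100"
    elif head_dim <= 256 and capability >= (8, 0):  # A100 branch from the thread
        return "A100"
    return "modest hardware"


def pick_branch_exact(capability, head_dim):
    """Hypothetical stricter variant: only sm_90 / sm_80 get the big configs."""
    if head_dim <= 256 and capability == (9, 0):
        return "H100"
    elif head_dim <= 256 and capability == (8, 0):
        return "A100"
    return "modest hardware"


a10 = (8, 6)  # compute capability of the NVIDIA A10
print(pick_branch(a10, 128))        # A100
print(pick_branch_exact(a10, 128))  # modest hardware
```

Exact matching is a blunt fix. Choosing the config based on the device's actual shared-memory limit would be more robust, since other GPUs also have less shared memory per block than the A100.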