pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Test with random cross attention #67

Closed · ssmmnn11 closed this 3 weeks ago

ssmmnn11 commented 3 weeks ago

Hi, this is great work!

I wanted to test an arbitrary mask, but I cannot get it to work and was wondering whether this is the right way to do it. I tried the code below, but with mask_mod2 I get:

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [1841,0,0], thread: [32,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.

... lib/python3.11/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1115     return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))

RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

===> example code:

import torch

from torch.nn.attention.flex_attention import (
    create_mask,
    create_block_mask
)

q = torch.randn([1, 1, 162, 16])
k = torch.randn([1, 1, 5248, 16])
v = torch.randn([1, 1, 5248, 16])

# 1-D table: the mask depends on the kv position only
zrand1 = torch.randint(0, 2, [k.shape[2]], device="cuda")
def mask_mod1(b, h, q_idx, kv_idx):
    return zrand1[kv_idx] > 0

# 2-D table: the mask depends on both the q and the kv position
zrand2 = torch.randint(0, 2, [q.shape[2], k.shape[2]], device="cuda")
def mask_mod2(b, h, q_idx, kv_idx):
    return zrand2[q_idx, kv_idx] > 0

mask_test1 = create_mask(mask_mod1, None, None, q.shape[2], v.shape[2], "cuda")
mask_test2 = create_mask(mask_mod2, None, None, q.shape[2], v.shape[2], "cuda")

# seems to work ...
block_mask1 = create_block_mask(mask_mod1, B=None, H=None, Q_LEN=q.shape[2], KV_LEN=v.shape[2])

# does not work :-(
block_mask = create_block_mask(mask_mod2, B=None, H=None, Q_LEN=q.shape[2], KV_LEN=v.shape[2])
ssmmnn11 commented 3 weeks ago

Figured it out: the query sequence length (162) was not a multiple of 128. create_block_mask evaluates the mask_mod at indices padded up to its 128-wide blocks, so zrand2, sized exactly [162, 5248], was indexed out of bounds along the query dimension. mask_mod1 was unaffected because 5248 is already a multiple of 128.
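
For anyone hitting the same assert, here is a minimal workaround sketch, assuming the cause above: allocate the lookup table at sizes rounded up to the next multiple of 128, so that every index the block-mask builder can probe stays in bounds (ceil_to_multiple is a helper named here for illustration):

import torch
from torch.nn.attention.flex_attention import create_block_mask

Q_LEN, KV_LEN = 162, 5248  # lengths from the repro above

def ceil_to_multiple(n, m):
    # round n up to the next multiple of m
    return ((n + m - 1) // m) * m

# size the table for the padded lengths so any probed index stays valid
Q_PAD = ceil_to_multiple(Q_LEN, 128)    # 256
KV_PAD = ceil_to_multiple(KV_LEN, 128)  # 5248, already a multiple of 128

zrand2 = torch.randint(0, 2, [Q_PAD, KV_PAD], device="cuda")

def mask_mod2(b, h, q_idx, kv_idx):
    return zrand2[q_idx, kv_idx] > 0

# the real lengths are still passed here; only the table is padded
block_mask2 = create_block_mask(mask_mod2, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN)

The padded rows correspond to query positions that do not exist at runtime, so their values should not affect the result; the padding only has to keep the indexing in bounds.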