pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Document masking does not work for a small number of tokens? #68

Closed · xidulu closed this issue 3 weeks ago

xidulu commented 3 weeks ago
import torch
from torch.nn.attention.flex_attention import create_block_mask

# 100 tokens split into documents: two of length 10, then four of length 20.
document_id = torch.zeros(100, dtype=torch.int, device="cpu")
document_id[:10] = 0
document_id[10:20] = 1
for i in range(20, 100, 20):
    document_id[i : i + 20] = i // 20 + 1

def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_id[q_idx] == document_id[kv_idx]
    return causal_mask & document_mask

mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu")
print(mask)

This raises the following exception:

File ~/work/anaconda3/lib/python3.9/site-packages/torch/utils/_device.py:106, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
    104 if func in _device_constructors() and kwargs.get('device') is None:
    105     kwargs['device'] = self.device
--> 106 return func(*args, **kwargs)

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1115     return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))

IndexError: index 100 is out of bounds for dimension 0 with size 100

which I believe comes from the indexing operation document_id[q_idx].
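
If the mask is being evaluated over indices rounded up to a 128-token block (which would explain index 100 showing up at all), then a rough sketch of a workaround is to pad the document_id table out to that boundary. The BLOCK and PAD_DOC_ID names below are just for this sketch, not part of the library:

import torch
from torch.nn.attention.flex_attention import create_block_mask

BLOCK = 128        # assumed internal default block size
PAD_DOC_ID = -1    # sentinel for the padded tail

# Same document layout as above, but the lookup table is padded to BLOCK
# entries so indices up to BLOCK - 1 stay in bounds. Padded positions only
# match each other and sit beyond Q_LEN/KV_LEN, so they don't affect the
# real 100-token sequence.
document_id = torch.full((BLOCK,), PAD_DOC_ID, dtype=torch.int, device="cpu")
document_id[:10] = 0
document_id[10:20] = 1
for i in range(20, 100, 20):
    document_id[i : i + 20] = i // 20 + 1

def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_id[q_idx] == document_id[kv_idx]
    return causal_mask & document_mask

# Q_LEN / KV_LEN stay at 100; only the lookup table is padded.
mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu")
print(mask)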

drisspg commented 3 weeks ago

Hmm what version of PyTorch are you using?

I just ran this and got:

BlockMask(shape=(1, 1, 128, 128), sparsity=0.00%, 
(0, 0)
██
)
xidulu commented 3 weeks ago

The version is 2.5.1+cu118 (which I believe is the latest release)!

[Screenshot attached: 2024-10-31 at 13:03:13]

Also see below for the complete stack trace:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [3], in <cell line: 12>()
      9     document_mask = document_id[q_idx] == document_id[kv_idx]
     10     return causal_mask & document_mask
---> 12 mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu")
     13 print(mask)

File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:850, in create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE, _compile)
    848     inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
    849 with TransformGetItemToIndex():
--> 850     partial_block_mask, full_block_mask = inner_func(
    851         mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE
    852     )
    853     block_mask = _create_sparse_block_from_block_mask(
    854         (partial_block_mask, full_block_mask), mask_mod
    855     )
    856 return block_mask

File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:775, in _create_block_mask_inner(mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE)
    761 def _create_block_mask_inner(
    762     mask_mod: Callable,
    763     B: int,
   (...)
    769     Q_BLOCK_SIZE: int,
    770 ):
    771     r"""Work around for being unable to instantiate __torch_function__ mode under compile.
    772     `create_block_mask` will compile this inner function and wrap the call to this
    773     with the __torch_function__ mode.
    774     """
--> 775     mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
    776     partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
    777         mask_tensor,
    778         KV_BLOCK_SIZE=KV_BLOCK_SIZE,
    779         Q_BLOCK_SIZE=Q_BLOCK_SIZE,
    780         separate_full_blocks=True,
    781     )
    782     return partial_block_mask, full_block_mask

File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:755, in create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device, _compile)
    753     mask_mod = mod_fn
    754     mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
--> 755     mask = mask_mod(b, h, m, n)
    756     return mask
    757 else:

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

    [... skipping similar frames: _flat_vmap at line 479 (1 times), vmap_impl at line 331 (1 times), vmap.<locals>.wrapped at line 203 (1 times)]

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

Input In [3], in document_causal_mask(b, h, q_idx, kv_idx)
      7 def document_causal_mask(b, h, q_idx, kv_idx):
      8     causal_mask = q_idx >= kv_idx
----> 9     document_mask = document_id[q_idx] == document_id[kv_idx]
     10     return causal_mask & document_mask

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_higher_order_ops/flex_attention.py:84, in TransformGetItemToIndex.__torch_function__(self, func, types, args, kwargs)
     82     index_args = pytree.tree_leaves(args[1])
     83     if all(isinstance(x, torch.Tensor) for x in index_args):
---> 84         return torch.ops.aten.index(args[0], index_args)
     85 return func(*args, **(kwargs or {}))

File ~/work/anaconda3/lib/python3.9/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1115     return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))

IndexError: index 100 is out of bounds for dimension 0 with size 100
xidulu commented 3 weeks ago

Update: Explicitly passing BLOCK_SIZE=100 fixes the problem... but is it required?
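
For reference, this is the call that works for me, using the same mask_mod and document_id as in the original snippet (BLOCK_SIZE is the parameter that appears in the create_block_mask signature in the trace above):

# With the block size clamped to the sequence length, the evaluated indices
# stay within the 100-entry document_id table, so document_id[q_idx] never
# goes out of bounds.
mask = create_block_mask(
    document_causal_mask, 1, 1, 100, 100, "cpu", BLOCK_SIZE=100
)
print(mask)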

drisspg commented 3 weeks ago

Could you try one more thing: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118

I think we have fixed this on nightly, and the fix hasn't made it into a stable release yet.
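
To double-check which build is actually being imported after installing (nightly wheels typically report a .dev version string rather than a plain release number):

import torch

# A nightly build usually prints something like "2.6.0.dev20241101+cu118".
print(torch.__version__)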

xidulu commented 3 weeks ago

@drisspg It's working now, thanks!

drisspg commented 3 weeks ago

Issue is fixed on nightly.