Closed: xidulu closed this issue 3 weeks ago.
Hmm what version of PyTorch are you using?
I just ran this and got:
BlockMask(shape=(1, 1, 128, 128), sparsity=0.00%,
(0, 0)
██
)
2.5.1+cu118 is the version (which I believe is the latest release)!
Also see below for the complete stack trace:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Input In [3], in <cell line: 12>()
9 document_mask = document_id[q_idx] == document_id[kv_idx]
10 return causal_mask & document_mask
---> 12 mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu")
13 print(mask)
File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:850, in create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE, _compile)
848 inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
849 with TransformGetItemToIndex():
--> 850 partial_block_mask, full_block_mask = inner_func(
851 mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE
852 )
853 block_mask = _create_sparse_block_from_block_mask(
854 (partial_block_mask, full_block_mask), mask_mod
855 )
856 return block_mask
File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:775, in _create_block_mask_inner(mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE)
761 def _create_block_mask_inner(
762 mask_mod: Callable,
763 B: int,
(...)
769 Q_BLOCK_SIZE: int,
770 ):
771 r"""Work around for being unable to instantiate __torch_function__ mode under compile.
772 `create_block_mask` will compile this inner function and wrap the call to this
773 with the __torch_function__ mode.
774 """
--> 775 mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
776 partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
777 mask_tensor,
778 KV_BLOCK_SIZE=KV_BLOCK_SIZE,
779 Q_BLOCK_SIZE=Q_BLOCK_SIZE,
780 separate_full_blocks=True,
781 )
782 return partial_block_mask, full_block_mask
File ~/work/anaconda3/lib/python3.9/site-packages/torch/nn/attention/flex_attention.py:755, in create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device, _compile)
753 mask_mod = mod_fn
754 mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
--> 755 mask = mask_mod(b, h, m, n)
756 return mask
757 else:
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
[... skipping similar frames: _flat_vmap at line 479 (1 times), vmap_impl at line 331 (1 times), vmap.<locals>.wrapped at line 203 (1 times)]
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
Input In [3], in document_causal_mask(b, h, q_idx, kv_idx)
7 def document_causal_mask(b, h, q_idx, kv_idx):
8 causal_mask = q_idx >= kv_idx
----> 9 document_mask = document_id[q_idx] == document_id[kv_idx]
10 return causal_mask & document_mask
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_higher_order_ops/flex_attention.py:84, in TransformGetItemToIndex.__torch_function__(self, func, types, args, kwargs)
82 index_args = pytree.tree_leaves(args[1])
83 if all(isinstance(x, torch.Tensor) for x in index_args):
---> 84 return torch.ops.aten.index(args[0], index_args)
85 return func(*args, **(kwargs or {}))
File ~/work/anaconda3/lib/python3.9/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1115 return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))
IndexError: index 100 is out of bounds for dimension 0 with size 100
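For completeness, here is a minimal reproduction reconstructed from the input cell shown in the traceback. The exact `document_id` construction isn't in the thread, so the tensor below is an assumed stand-in (a length-100 tensor of document labels):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Assumed setup (not shown in the thread): 100 tokens split across three documents.
document_id = torch.zeros(100, dtype=torch.int64)
document_id[40:80] = 1
document_id[80:] = 2

def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_id[q_idx] == document_id[kv_idx]
    return causal_mask & document_mask

# On 2.5.1 this fails with:
# IndexError: index 100 is out of bounds for dimension 0 with size 100
mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu")
print(mask)
```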
Update: Explicitly passing BLOCK_SIZE=100 fixes the problem... but is that required?
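A sketch of that workaround, using the same assumed setup as the reproduction above (`BLOCK_SIZE` is the parameter visible in the `create_block_mask` signature in the traceback):

```python
# Workaround on 2.5.1: make the block size equal to the sequence length so the
# mask function is never evaluated at indices >= 100.
mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cpu", BLOCK_SIZE=100)
print(mask)
```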
Could you try one more thing: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
I think we have fixed this in the nightly builds and it hasn't made it into a stable release yet.
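(A quick sanity check after installing, just to confirm the nightly build is the one being imported:)

```python
import torch
print(torch.__version__)  # nightly builds report a ".dev" version string rather than 2.5.1
```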
@drisspg It's working now, thanks!
Issue is fixed on nightly.
For context: the snippet above raises an exception, which I believe comes from the indexing operation document_id[q_idx].
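My understanding of the failure, sketched below: with the default BLOCK_SIZE of 128, the mask appears to be evaluated over indices rounded up to the block size, so q_idx can reach 127 while document_id only has 100 entries (same assumed setup as above):

```python
import torch

document_id = torch.zeros(100, dtype=torch.int64)  # only 100 entries

# Indices padded up to the default 128 block size reach past the end of the tensor,
# reproducing the IndexError from the traceback:
q_idx = torch.arange(128)
document_id[q_idx]  # IndexError: index 100 is out of bounds for dimension 0 with size 100
```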