huseinzol05 opened 1 month ago
This is to replicate the original JAX implementation, https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/blockwise-vanilla-attention-causal.ipynb
I simply generate a global multiplier mask (I was too lazy to build the additive version):
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()
attn_bias_blocks = torch.chunk(temp_mask, chunk_size)
attn_bias_block = attn_bias_blocks[0]
seq_chunk = Q.shape[0] // chunk_size
attn_bias_b = attn_bias_block[:, no * seq_chunk: (no + 1) * seq_chunk]
scores = torch.matmul(Q_block, K_block.T) * attn_bias_b
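(Just for comparison with the additive approach below: the same boolean block could be converted into an additive bias, roughly like this sketch, where the shapes are only illustrative:)
import torch

Q_block = torch.randn(5, 64)
K_block = torch.randn(5, 64)
attn_bias_b = torch.ones(5, 5, dtype=torch.bool).tril()  # boolean block of the global mask, True = attend

# additive form: 0 where attended, -inf where masked out
additive_bias = torch.zeros(attn_bias_b.shape).masked_fill(~attn_bias_b, float("-inf"))
scores = torch.matmul(Q_block, K_block.T) + additive_bias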
The original JAX implementation, on the other hand, generates an additive mask inside the blockwise loop:
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
        bias, segment_ids, deterministic, attn_dropout, attn_pdrop, causal,
        dtype, query_chunk_idx, key_chunk_idx):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, 0, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )

    if segment_ids is not None:
        q_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, query_offset),
            slice_sizes=(segment_ids.shape[0], query_chunk_size)
        )
        k_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, key_offset),
            slice_sizes=(segment_ids.shape[0], key_chunk_size)
        )
        segment_ids_mask = q_segment_ids[:, :, None] != k_segment_ids[:, None, :]
        segment_ids_mask = segment_ids_mask[:, None]  # B1QK
        segment_ids_bias = segment_ids_mask * jnp.finfo(dtype).min
        chunk_bias += segment_ids_bias

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)

_chunk_bias_fn = partial(
    _chunk_attention_bias,
    query_chunk_size, key_chunk_size, bias, segment_ids, deterministic,
    attn_dropout, attn_pdrop, causal, dtype)
bias_chunk = _chunk_bias_fn(q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
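The part that matters here is the if causal: branch: the query/key offsets shift the per-chunk indices into global positions, so every (query chunk, key chunk) pair gets the right piece of the global lower-triangular mask. A rough PyTorch equivalent of just that branch (the function name and sizes are mine, purely for illustration):

import torch

def chunk_causal_bias(query_chunk_size, key_chunk_size, query_chunk_idx, key_chunk_idx):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    query_idx = torch.arange(query_chunk_size).unsqueeze(1) + query_offset  # global query positions
    key_idx = torch.arange(key_chunk_size).unsqueeze(0) + key_offset        # global key positions
    bias = torch.zeros(query_chunk_size, key_chunk_size)
    return bias.masked_fill(query_idx < key_idx, float("-inf"))

print(chunk_causal_bias(5, 5, 0, 0))  # query chunk 0 vs key chunk 0: lower triangle
print(chunk_causal_bias(5, 5, 1, 0))  # query chunk 1 vs key chunk 0: all zeros (fully visible)
print(chunk_causal_bias(5, 5, 0, 1))  # query chunk 0 vs key chunk 1: all -inf (fully masked)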
Any comments, @zhuzilin?
Hmm... I'm not sure what you are aiming at. If you just want to be sure that this implementation supports the causal mask, you can try running the code in the test folder.
The code in the repo is not a step-by-step port of the original JAX implementation, and I actually haven't read that before...
Sorry, I just want to verify the causal masking because based on the code,
for step in range(comm.world_size):
    if step + 1 != comm.world_size:
        next_k: torch.Tensor = comm.send_recv(k)
        next_v: torch.Tensor = comm.send_recv(v)
        comm.commit()
    if not causal or step <= comm.rank:
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal and step == 0,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
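For reference, tracing the if not causal or step <= comm.rank condition above for world_size == 2 and causal == True, just printing which (rank, step) pairs run the kernel and with what causal flag:

world_size, causal = 2, True
for rank in range(world_size):
    for step in range(world_size):
        if not causal or step <= rank:
            print(f"rank {rank}, step {step}: compute, causal={causal and step == 0}")
        else:
            print(f"rank {rank}, step {step}: skipped")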
Let's say I have Q/K/V of size [L, dim] = [10, 100], with a causal mask of shape [10, 10]:
tensor([[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]])
Now I chunk along the sequence dimension across 2 devices, so each device gets Q/K/V of size [5, 100] and a causal mask of shape [5, 10]:
device 0: [5, 100] mask [5, 10]
tensor([[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False]])
device 1: [5, 100] mask [5, 10]
tensor([[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]])
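These two mask chunks can be reproduced with the same torch.chunk call as in my snippet above:

import torch

L = 10
mask = torch.ones(L, L, dtype=torch.bool).tril(diagonal=0)
mask0, mask1 = torch.chunk(mask, 2, dim=0)  # device 0 and device 1 masks, each [5, 10]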
On device 0, the blockwise computation is
(q0 k0 * mask[:, 0:5]) v0 + (q0 k1 * mask[:, 5:10]) v1
where mask[:, 0:5],
tensor([[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]])
where mask[:, 5:10],
tensor([[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False]])
On device 1, the blockwise computation is
(q1 k0 * mask[:, 0:5]) v0 + (q1 k1 * mask[:, 5:10]) v1
where mask[:, 0:5],
tensor([[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]])
where mask[:, 5:10],
tensor([[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]])
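To make the block merging concrete, here is a small self-contained sketch (plain PyTorch, not the repo's code; it uses additive -inf masks and a standard log-sum-exp combine) showing that the two device-1 blocks above reproduce the last five rows of full causal attention:

import torch

torch.manual_seed(0)
L, dim = 10, 16
q, k, v = torch.randn(L, dim), torch.randn(L, dim), torch.randn(L, dim)
mask = torch.ones(L, L, dtype=torch.bool).tril()

# reference: full causal attention
ref = torch.softmax((q @ k.T).masked_fill(~mask, float("-inf")), dim=-1) @ v

def block_attn(q_blk, k_blk, v_blk, mask_blk):
    # one blockwise step: per-row normalised output plus log-sum-exp of the block's scores
    scores = (q_blk @ k_blk.T).masked_fill(~mask_blk, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    return torch.exp(scores - lse) @ v_blk, lse

def merge(out_a, lse_a, out_b, lse_b):
    # log-sum-exp combine of two blocks computed over disjoint key chunks
    lse = torch.logaddexp(lse_a, lse_b)
    return torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b

q0, q1 = torch.chunk(q, 2)
k0, k1 = torch.chunk(k, 2)
v0, v1 = torch.chunk(v, 2)
mask0, mask1 = torch.chunk(mask, 2, dim=0)

# device 1: (q1 k0 * mask[:, 0:5]) v0 + (q1 k1 * mask[:, 5:10]) v1
out_a, lse_a = block_attn(q1, k0, v0, mask1[:, 0:5])
out_b, lse_b = block_attn(q1, k1, v1, mask1[:, 5:10])
print(torch.allclose(merge(out_a, lse_a, out_b, lse_b), ref[5:], atol=1e-5))  # True

# note: device 0 only has a real contribution from its first block; mask0[:, 5:10] is all False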
Now back to the forward code,
for step in range(comm.world_size):
    if step + 1 != comm.world_size:
        next_k: torch.Tensor = comm.send_recv(k)
        next_v: torch.Tensor = comm.send_recv(v)
        comm.commit()
    if not causal or step <= comm.rank:
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal and step == 0,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
Let's say we are currently on device 0, and world_size == 2, so device 0 will calculate q0 k0 v0 + q0 k1 v1.
When step == 0, causal is True (from causal and step == 0), so flash attention will generate a lower-triangular causal mask, but what we need is:
tensor([[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]])
When step == 1, causal is False (from causal and step == 0), so flash attention will do full attention, but what we need is:
tensor([[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False]])
The same goes for device 1.
@zhuzilin Hopefully you can understand what I am trying to share, lol
Hmm... there will never be a rectangular mask, as all the k and q chunks have the same sequence length. And there will also never be a mask that is all False, which would actually mean doing no calculation at all...
Regardless, the qk-mask-v calculation still happens; it just produces really, really small values, and the later merge will produce the correct result. What are your thoughts on this?
Hi @zhuzilin, following up from https://github.com/zhuzilin/ring-flash-attention/issues/15
I just wanted to verify the causal case. I simply use a loop because I don't have multiple GPUs, but it should behave the same. When I do causal attention using your ring logic, the argmax accuracy is super low, but when I do non-causal, the accuracy is almost a perfect 100%. You can check the notebook at https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/flash-ring-attention-causal.ipynb
From what I understand, say I have 2 devices and a seqlen of 100k, partitioned into 2 chunks of 100k // 2 = 50k each, so each device holds a 50k chunk:
device 0: q0 k0 v0
device 1: q1 k1 v1
So the blockwise attention calculation is:
device 0: q0 k0 v0 + q0 k1 v1
device 1: q1 k0 v0 + q1 k1 v1
where (+) denotes the blockwise attention combine.
For the causal case an attention mask is necessary. The original attention mask is [100k, 100k], and we must chunk it properly into mask0 = [50k, 100k] and mask1 = [50k, 100k], so the blockwise attention calculation becomes:
device 0: (q0 k0 * mask0[:, 0:50k]) v0 + (q0 k1 * mask0[:, 50k:100k]) v1
device 1: (q1 k0 * mask1[:, 0:50k]) v0 + (q1 k1 * mask1[:, 50k:100k]) v1
You can see this slicing in the original implementation: https://github.com/forhaoliu/ringattention/blob/main/ringattention/ringattention_pallas_tpu.py#L61
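A toy version of that slicing (seqlen 8 standing in for 100k, 2 devices), just to make the shapes concrete:

import torch

seqlen, world_size = 8, 2
chunk = seqlen // world_size
mask = torch.ones(seqlen, seqlen, dtype=torch.bool).tril()
mask0, mask1 = torch.chunk(mask, world_size, dim=0)          # [4, 8] each
# per-step key-column slices used in the blockwise sums above
m0_k0, m0_k1 = mask0[:, 0:chunk], mask0[:, chunk:2 * chunk]  # q0 vs k0, q0 vs k1
m1_k0, m1_k1 = mask1[:, 0:chunk], mask1[:, chunk:2 * chunk]  # q1 vs k0, q1 vs k1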
Correct me if I'm wrong here, thanks!