zhuzilin / ring-flash-attention

Ring attention implementation with flash attention
MIT License

verify causal masking #54

Open huseinzol05 opened 1 month ago

huseinzol05 commented 1 month ago

Hi @zhuzilin, following up from https://github.com/zhuzilin/ring-flash-attention/issues/15

I just wanted to verify the causal masking. I simply use a loop because I don't have multiple GPUs, but the logic should be equivalent. When I apply causal masking using your ring logic, the argmax accuracy is very low, but when I run non-causal, the accuracy is almost a perfect 100%. You can check the notebook at https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/flash-ring-attention-causal.ipynb

From what I understand, let's say I have 2 devices and a seqlen of 100k, partitioned into 2: 100k // 2 = 50k and 50k. So,

each device holds a 50k seq len:

device 0: 50k q0k0v0
device 1: 50k q1k1v1

So the blockwise attention calculation is:

device 0: 50k q0k0v0 + 50k q0k1v1
device 1: 50k q1k0v0 + 50k q1k1v1

Here (+) denotes the blockwise attention combination.

For the causal case, an attention mask is necessary. The original attention mask is [100k, 100k], and we must chunk it properly into mask0 = [50k, 100k] and mask1 = [50k, 100k], so the blockwise attention calculation becomes:

device 0: 50k (q0k0 mask0[:, 0:50k])v0 + 50k (q0k1 mask0[:, 50k:100k])v1
device 1: 50k (q1k0 mask1[:, 0:50k])v0 + 50k (q1k1 mask1[:, 50k:100k])v1
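As a small illustration of that chunking and slicing (a standalone toy sketch with seqlen 8 standing in for 100k; not code from this repo):

import torch

seqlen, world_size = 8, 2
chunk = seqlen // world_size  # 4 here, 50k in the example above

full_mask = torch.ones(seqlen, seqlen, dtype=torch.bool).tril(diagonal=0)
mask0, mask1 = torch.chunk(full_mask, world_size)  # each [chunk, seqlen], i.e. [50k, 100k]

# per-block slices used in the blockwise sums above
mask0_k0 = mask0[:, 0:chunk]   # q0 vs k0: lower triangular
mask0_k1 = mask0[:, chunk:]    # q0 vs k1: all False (q0 never attends to k1)
mask1_k0 = mask1[:, 0:chunk]   # q1 vs k0: all True
mask1_k1 = mask1[:, chunk:]    # q1 vs k1: lower triangular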

You can see this slicing in the original implementation: https://github.com/forhaoliu/ringattention/blob/main/ringattention/ringattention_pallas_tpu.py#L61

Correct me if I'm wrong here, thanks!

huseinzol05 commented 1 month ago

This is to replicate the original JAX implementation: https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/blockwise-vanilla-attention-causal.ipynb

I just generate a global multiplicative mask (I was too lazy to do an additive one):

temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()  # full [L, S] lower-triangular causal mask (multiplicative, boolean)
attn_bias_blocks = torch.chunk(temp_mask, chunk_size)  # split the query rows into chunk_size blocks (one per device)
attn_bias_block = attn_bias_blocks[0]  # rows owned by this device
seq_chunk = Q.shape[0] // chunk_size  # length of each key/value block
attn_bias_b = attn_bias_block[:, no * seq_chunk: (no + 1) * seq_chunk]  # columns for key block no
scores = torch.matmul(Q_block, K_block.T) * attn_bias_b  # zero out the masked scores (multiplicative masking)
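For comparison with the JAX version below, the additive form of the same slice would look roughly like this (a sketch that reuses Q_block, K_block, and attn_bias_b from the snippet above; neg_inf is just an illustrative name):

# additive form: 0 where attention is allowed, a very large negative number
# (added to the scores before softmax) where it is not
neg_inf = torch.finfo(Q_block.dtype).min
additive_bias = torch.zeros_like(attn_bias_b, dtype=Q_block.dtype).masked_fill(~attn_bias_b, neg_inf)
scores_additive = torch.matmul(Q_block, K_block.T) + additive_bias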

Meanwhile, the original JAX implementation generates an additive mask during the blockwise computation:

def _chunk_attention_bias(query_chunk_size, key_chunk_size,
            bias, segment_ids, deterministic, attn_dropout, attn_pdrop, causal,
            dtype, query_chunk_idx, key_chunk_idx):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, 0, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )

    if segment_ids is not None:
        q_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, query_offset),
            slice_sizes=(segment_ids.shape[0], query_chunk_size)
        )
        k_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, key_offset),
            slice_sizes=(segment_ids.shape[0], key_chunk_size)
        )
        segment_ids_mask = q_segment_ids[:, :, None] != k_segment_ids[:, None, :]
        segment_ids_mask = segment_ids_mask[:, None] # B1QK
        segment_ids_bias = segment_ids_mask * jnp.finfo(dtype).min
        chunk_bias += segment_ids_bias

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)
_chunk_bias_fn = partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, segment_ids, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)
bias_chunk = _chunk_bias_fn(q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
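To make the causal branch of _chunk_attention_bias concrete, here is a small re-expression of just that branch in PyTorch (my own toy sketch, not code from either repo), evaluated for two chunks of length 5:

import torch

def causal_chunk_bias(query_chunk_size, key_chunk_size, query_chunk_idx, key_chunk_idx,
                      dtype=torch.float32):
    # mirrors the `if causal:` branch above: shift the query indices by the
    # difference between the query and key chunk offsets, then mask q < k
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    query_idx = torch.arange(query_chunk_size).reshape(-1, 1) + (query_offset - key_offset)
    key_idx = torch.arange(key_chunk_size).reshape(1, -1)
    return (query_idx < key_idx).to(dtype) * torch.finfo(dtype).min

print(causal_chunk_bias(5, 5, query_chunk_idx=0, key_chunk_idx=0))  # strictly-upper triangle masked
print(causal_chunk_bias(5, 5, query_chunk_idx=0, key_chunk_idx=1))  # fully masked
print(causal_chunk_bias(5, 5, query_chunk_idx=1, key_chunk_idx=0))  # fully unmasked (all zeros)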

huseinzol05 commented 1 month ago

Any comment, @zhuzilin?

zhuzilin commented 1 month ago

hmm... I'm not sure what you are aiming at. If you just want to be sure that this implementation supports causal mask, you can try running the code in the test folder.
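For example, something along these lines (the exact script name here is an assumption; use whichever file is actually in the test folder):

torchrun --nproc_per_node 2 test/test_ring_flash_attn_func.py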

The code in this repo is not a step-by-step port of the original JAX implementation, and I actually haven't read that code before...

huseinzol05 commented 1 month ago

Sorry, I just want to verify the causal masking because based on the code,

for step in range(comm.world_size):
    if step + 1 != comm.world_size:
        next_k: torch.Tensor = comm.send_recv(k)
        next_v: torch.Tensor = comm.send_recv(v)
        comm.commit()

    if not causal or step <= comm.rank:
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal and step == 0,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
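For reference, a minimal sketch (plain Python, assuming world_size == 2 and causal == True; rank and step stand in for comm.rank and the loop variable) of which blocks this condition lets each rank compute, and with which causal flag:

world_size = 2
causal = True

for rank in range(world_size):
    for step in range(world_size):
        if not causal or step <= rank:
            # the flag passed to _flash_attn_forward in the loop above
            block_causal = causal and step == 0
            print(f"rank {rank}, step {step}: compute against the kv held at this step, causal={block_causal}")
        else:
            print(f"rank {rank}, step {step}: skipped by the step <= comm.rank check")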

Let's say I have q, k, v of size [L, dim] = [10, 100], with a causal mask of [10, 10]:

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

Now I chunk on the sequence dimension across 2 devices, so each device gets q, k, v of size [5, 100] and a causal mask of [5, 10]:

device 0: [5, 100] mask [5, 10]

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])

device 1: [5, 100] mask [5, 10]

tensor([[ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

On device 0, the blockwise computation is

q0 k0 mask[:, 0:5] v0 + q0 k1 mask[:, 5:10] v1

where mask[:, 0:5],

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

where mask[:, 5:10],

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

On device 1, the blockwise computation is

q1 k0 mask[:, 0:5] v0 + q1 k1 mask[:, 5:10] v1

where mask[:, 0:5],

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

where mask[:, 5:10],

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

Now back to the forward code,

for step in range(comm.world_size):
    if step + 1 != comm.world_size:
        next_k: torch.Tensor = comm.send_recv(k)
        next_v: torch.Tensor = comm.send_recv(v)
        comm.commit()

    if not causal or step <= comm.rank:
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal and step == 0,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)

Let's say we are currently on device 0 and world_size == 2, so device 0 will calculate q0k0v0 + q0k1v1.

When step == 0, causal is True based on causal and step == 0, so flash attention will generate a lower-triangle causal mask, but we need,

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

When step == 1, causal is False based on causal and step == 0, so flash attention will do full attention, but we need,

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

The same goes for device 1.

@zhuzilin Hopefully you can understand what I am trying to share, lol

zhuzilin commented 1 month ago

hmm... there will never be a rectangular mask as all the k and q chunks will have the same sequence length. And there will also not be a mask with all False, which actually means doing no calculation...

huseinzol05 commented 1 month ago

Regardless, the qk-mask-v calculation still happens; it just produces really, really small values, and the later merging will produce correct results. What are your thoughts on this?
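For reference on the merging step, here is a generic standalone sketch (this is not the repo's update_out_and_lse, and masking is omitted for brevity) of how two blockwise outputs are combined through their log-sum-exps so that the result matches full attention:

import torch

def attn_block(q, k, v, scale):
    # partial output and log-sum-exp over one key/value block
    scores = (q @ k.T) * scale
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse

def merge_blocks(out1, lse1, out2, lse2):
    # re-weight each partial output by its share of the total softmax mass
    lse = torch.logaddexp(lse1, lse2)
    out = torch.exp(lse1 - lse).unsqueeze(-1) * out1 + torch.exp(lse2 - lse).unsqueeze(-1) * out2
    return out, lse

torch.manual_seed(0)
q, k, v = torch.randn(5, 16), torch.randn(10, 16), torch.randn(10, 16)
scale = 16 ** -0.5

out_a, lse_a = attn_block(q, k[:5], v[:5], scale)
out_b, lse_b = attn_block(q, k[5:], v[5:], scale)
merged, _ = merge_blocks(out_a, lse_a, out_b, lse_b)

reference = torch.softmax((q @ k.T) * scale, dim=-1) @ v
print(torch.allclose(merged, reference, atol=1e-5))  # True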