zhanglei1172 opened this issue 10 months ago
Hi @zhanglei1172, the standard input format is (batches, nheads, seqlen, d).
Note that N_CTX_2 (batches x nheads x sequence_length) is introduced in the PR to support Hopper TMA. I think the kernel you used mixes the two (for comparison, a consistent setup is sketched after this list):
- you loaded K_block_ptr with shape (HIDDEN_DIM, N_CTX_2) in column-major layout (note that the transpose is not necessary if you check the latest implementation: you can load the k slice in row-major layout and take an inner dot between the q tile and the k tile);
- you loaded V_block_ptr with shape (N_CTX_2, HIDDEN_DIM) in row-major layout;
- however, you loaded Q_block_ptr with shape (N_CTX, HIDDEN_DIM).
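For comparison, a consistent self-attention setup in the style of the upstream 06-fused-attention tutorial looks roughly like the kernel-body fragment below; the stride names and N_CTX follow the tutorial's kernel signature, so treat this as a sketch rather than a drop-in fix:

```python
# Sketch in the style of the upstream tutorial (stride names are the tutorial's;
# adapt them to your kernel's signature).
Q_block_ptr = tl.make_block_ptr(
    base=Q + qvk_offset,
    shape=(N_CTX, BLOCK_DMODEL),          # Q slice, row-major
    strides=(stride_qm, stride_qk),
    offsets=(start_m * BLOCK_M, 0),
    block_shape=(BLOCK_M, BLOCK_DMODEL),
    order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
    base=K + qvk_offset,
    shape=(BLOCK_DMODEL, N_CTX),          # K slice viewed already transposed
    strides=(stride_kk, stride_kn),
    offsets=(0, 0),
    block_shape=(BLOCK_DMODEL, BLOCK_N),
    order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
    base=V + qvk_offset,
    shape=(N_CTX, BLOCK_DMODEL),          # V slice, row-major
    strides=(stride_vk, stride_vn),
    offsets=(0, 0),
    block_shape=(BLOCK_N, BLOCK_DMODEL),
    order=(1, 0),
)
# The inner loop can then accumulate qk with tl.dot(q, k) on these tiles.
```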
Look at the CUDA blocks created:

```python
grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)
```

You have created 1 x nheads x sequence_length / BLOCK_M blocks.
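In other words, axis 0 of the grid walks the BLOCK_M-row tiles of the sequence and axis 1 walks q.shape[0]; inside the kernel that maps to something like the fragment below (names such as stride_h are illustrative, not your kernel's):

```python
# Illustrative kernel-body fragment: how each program locates its tile.
start_m = tl.program_id(0)        # which BLOCK_M-row tile of Q along the sequence
off_h = tl.program_id(1)          # which slice along q.shape[0] (batch/head)
qvk_offset = off_h * stride_h     # hypothetical per-slice base offset into Q/K/V
q_row_offset = start_m * BLOCK_M  # first Q row owned by this program
```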
Q is like a two-dimensional matrix with **(1 x nheads x sequence_length / BLOCK_M)** row tiles and **emb_d** columns, s.t. Q = [q1, q2, ...]^T.
K is another two-dimensional matrix (we don't need to transpose it) with **(1 x nheads x sequence_length / BLOCK_N)** row tiles and **emb_d** columns, s.t. K = [k1, k2, ...]^T.
To calculate causal attention we load q_i and then iterate over the K/V tiles (a sketch of this loop follows the list):
- iter 0: k1, v1
- iter 1: k2, v2
- ...
- iter hi: k_hi, v_hi, each of shape (BLOCK_N, emb_d)
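The iteration above is the usual online-softmax accumulation. A plain NumPy sketch of that pattern for a single q tile (purely illustrative, not the Triton kernel; causal masking and the 1/sqrt(d) scale are omitted):

```python
import numpy as np

def attention_row_tile(q_i, K, V, BLOCK_N):
    """Online-softmax accumulation over K/V tiles for one (BLOCK_M, emb_d) tile q_i.
    Reference for the loop structure only; masking and scaling are omitted."""
    m = np.full(q_i.shape[0], -np.inf)           # running row max
    l = np.zeros(q_i.shape[0])                   # running softmax denominator
    acc = np.zeros(q_i.shape, dtype=np.float64)  # running (numerator @ V)
    for start in range(0, K.shape[0], BLOCK_N):
        k_j = K[start:start + BLOCK_N]           # (BLOCK_N, emb_d)
        v_j = V[start:start + BLOCK_N]           # (BLOCK_N, emb_d)
        s = q_i @ k_j.T                          # (BLOCK_M, BLOCK_N) scores
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])
        alpha = np.exp(m - m_new)                # rescale factor for old statistics
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v_j
        m = m_new
    return acc / l[:, None]
```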
Can you check the source code version you used and explain N_CTX_2 in your case?
Hi, I use triton==2.1.0, and N_CTX_2 means the sequence_length of K/V, while N_CTX means the sequence_length of Q.
I checked the tutorials: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
and https://github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py#L42 . Both use a transpose.
And I tried changing my second version above:
```python
K_block_ptr = tl.make_block_ptr(
    base=K + qvk_offset,
    shape=(N_CTX_2, HIDDEN_DIM),
    strides=(stride_kk, stride_kn),
    offsets=(0, 0),
    block_shape=(BLOCK_N, BLOCK_DMODEL),
    order=(1, 0),
)
```

```python
qk += tl.dot(q, tl.trans(k), allow_tf32=True)
```
Results:

```
max diff: 0.039794921875
max diff: 0.0565185546875
max diff: 0.06787109375
max diff: 0.04058837890625
```
@zhanglei1172 The code you referred to is the updated version with patch PR#2336. The definition of N_CTX_2 is here:
```python
q.shape[0] * q.shape[1] * q.shape[2]
```
It is not the sequence_length of K or V; more precisely, it treats Q, K, V as 2-dimensional matrices of shape (-1, d), (d, -1), (d, -1), where the -1 axis spans values repeated q.shape[0] * q.shape[1] * q.shape[2] times along the sequence-length direction.
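A small illustration of that flattened view (assumed shapes; not code from the PR):

```python
import torch

# Hypothetical illustration: Q and K viewed as the 2D matrices described above.
Z, H, N_CTX, D = 2, 4, 1024, 64                  # assumed shapes
q = torch.randn(Z, H, N_CTX, D)
k = torch.randn(Z, H, N_CTX, D)
N_CTX_2 = q.shape[0] * q.shape[1] * q.shape[2]   # batches * nheads * seqlen
q_2d = q.reshape(N_CTX_2, D)                     # Q viewed as (-1, d)
k_2d = k.reshape(N_CTX_2, D).T                   # K viewed as (d, -1)
```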
So I think what is used here is a mix of Flash Attention v2 (online softmax with memory-efficient attention, with Q loaded first in the outer loop) and the attention variant with Hopper TMA support.
@yiakwy-xpu-ml-framework-team You mentioned that N_CTX_2 (batches x nheads x sequence_length) is introduced in the PR to support Hopper TMA, but the Triton version I'm using (and referring to) does not contain that PR. Instead, I introduced N_CTX_2 myself (I set it up on my own and did not refer to the code in that Hopper TMA PR) for scenarios dealing with cross attention, where Q and K/V have inconsistent sequence lengths. The code I used is intended to differ from the original attention in two main ways: the permutation of the input shape, and the different sequence lengths of Q and K/V.
So ultimately I want to confirm whether the current Triton can support this kind of permutation or cross attention, and if it can, whether I can modify the original code to get correct results.
This version gives correct output, but when I change the QKV layout from (seqlen, nheads, d) to (nheads, seqlen, d), I get wrong results.
It is also very strange that the error differs so much between q_ctx = 4095 and q_ctx = 4096 (0.0134429931640625 vs. 4.57763671875e-05).
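For debugging this, a plain PyTorch reference in the same (nheads, seqlen, d) layout may help separate the layout permutation from the q_ctx boundary effect; the shapes and the wrapper name below are assumptions, not your actual code:

```python
import torch

def reference_cross_attention(q, k, v):
    """Plain PyTorch cross attention; q: (nheads, q_ctx, d), k/v: (nheads, kv_ctx, d).
    Computed in fp32; masking omitted."""
    scale = q.shape[-1] ** -0.5
    s = torch.matmul(q.float(), k.float().transpose(-1, -2)) * scale  # (nheads, q_ctx, kv_ctx)
    p = torch.softmax(s, dim=-1)
    return torch.matmul(p, v.float()).to(q.dtype)

# Usage sketch (my_cross_attention is a placeholder for your Triton wrapper):
# out_tri = my_cross_attention(q, k, v)
# out_ref = reference_cross_attention(q, k, v)
# print("max diff:", (out_tri.float() - out_ref).abs().max().item())
```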