mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

How to explain the attention input QKV tensors and the comment "# xformers requires (B=1, S, H, D)" #65

Closed dhcode-cpp closed 10 months ago

dhcode-cpp commented 10 months ago

My batch size is 3 and window_size is 3; the input looks like:

sequences = ["11 12 13 14 15", "21 22 23 24 25 26 27", "31 32"]

I have two questions from debugging the Mistral model.

First, are the 3 batch sequences flattened into a single sequence, i.e. lengths [5, 7, 2] -> a tensor like [5+7+2, 1]?

Second, if the first point is true, how do we calculate attention?

  1. How should we understand the comment "xformers requires (B=1, S, H, D)"? If we merge the 3 batch entries into 1 sequence, wouldn't we compute attention across batches? (See the sketch after this list.)
  2. At the prefill step we process Q/K/V of shape [1, 17, 4, 128], but at the next step the second dim of q is 3 while k is 9. How should this output be interpreted? I think q is [q_b1, q_b2, q_b3] and k is [k_b1_window1, k_b1_window2, k_b1_window3, ..........]
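For reference, here is a minimal sketch of what the flattening would mean (my own illustration, not the actual mistral-inference code; it splits on whitespace instead of using the real tokenizer):

# Sketch: pack a batch of variable-length sequences into one (B=1, S) tensor,
# where S is the sum of the per-sequence lengths.
import torch

sequences = ["11 12 13 14 15", "21 22 23 24 25 26 27", "31 32"]
token_ids = [[int(t) for t in s.split()] for s in sequences]  # stand-in for real tokenization
seqlens = [len(ids) for ids in token_ids]                     # [5, 7, 2]

flat = torch.tensor([t for ids in token_ids for t in ids])    # shape [17]
flat = flat[None, :]                                          # shape [1, 17] -> B=1, S=5+7+2

print(seqlens, flat.shape)  # [5, 7, 2] torch.Size([1, 17])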

We print the Q/K/V shapes in mistral/model.py just before the attention call:

# xformers requires (B=1, S, H, D)
# add a leading batch dimension of 1: (S, H, D) -> (1, S, H, D)
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]

print('q:', xq.shape)
print('k:', key.shape)
print('v:', val.shape)

# output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)

and the printed output is as follows (the number of layers is 2, n_kv_heads = 4 and n_heads = 4):

------------------ 0
cur_layer_id : 0
q: torch.Size([1, 17, 4, 128])
k: torch.Size([1, 17, 4, 128])
v: torch.Size([1, 17, 4, 128])
------------------ 1
cur_layer_id : 1
q: torch.Size([1, 17, 4, 128])
k: torch.Size([1, 17, 4, 128])
v: torch.Size([1, 17, 4, 128])
------------------ 0
cur_layer_id : 0
q: torch.Size([1, 3, 4, 128])
k: torch.Size([1, 9, 4, 128])
v: torch.Size([1, 9, 4, 128])
------------------ 1
cur_layer_id : 1
q: torch.Size([1, 3, 4, 128])
k: torch.Size([1, 9, 4, 128])
v: torch.Size([1, 9, 4, 128])
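A hedged reading of these numbers (my own bookkeeping based on the printed values, not taken from the library internals): during prefill the three prompts are packed into one row, so S = 5 + 7 + 2 = 17; at the decode step each of the 3 sequences contributes exactly one new query token, so the q dim is 3, and the 9 keys would correspond to each sequence exposing at most window_size = 3 cached keys.

# Assumed shape arithmetic for the printout above (not repo code):
seqlens = [5, 7, 2]
window_size = 3

prefill_S = sum(seqlens)                                     # 17 -> q/k/v are [1, 17, H, D]
decode_q  = len(seqlens)                                     # 3  -> one new query token per sequence
decode_k  = sum(min(l + 1, window_size) for l in seqlens)    # 3 + 3 + 3 = 9 keys in the rotating cache

print(prefill_S, decode_q, decode_k)  # 17 3 9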

Mistral is an impressive work, and I'm excited to hear your response. Thank you very much!

dhcode-cpp commented 10 months ago

I found that the xformers API uses a block-diagonal mask, which keeps the concatenated batches independent (no cross-batch attention).

Ref: https://facebookresearch.github.io/xformers/components/ops.html (class xformers.ops.fmha.attn_bias.BlockDiagonalMask)


So the packed attention input is [1, batch1_len + batch2_len + batch3_len, H, K].
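A minimal sketch (my own illustration, not the repo's code) of how BlockDiagonalMask prevents cross-batch attention when the three sequences are packed into a single B=1 row; mistral-inference additionally applies a causal / sliding-window variant, which is not shown here:

import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

seqlens = [5, 7, 2]              # per-sequence lengths
S, H, D = sum(seqlens), 4, 128   # S = 17, n_heads = 4, head_dim = 128

# Packed layout: one batch row containing all three sequences back to back.
q = torch.randn(1, S, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(1, S, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(1, S, H, D, device="cuda", dtype=torch.float16)

# Block-diagonal bias: a query token may only attend to keys inside its own
# block (its own original sequence), so the packed batches stay independent.
bias = BlockDiagonalMask.from_seqlens(seqlens)

out = memory_efficient_attention(q, k, v, attn_bias=bias)
print(out.shape)  # torch.Size([1, 17, 4, 128])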
