Out of impatience I asked Claude 2 about the differences between Implementation A (LLaMA 1) and Implementation B (LLaMA 2):
Then I explained the concept of GQA and asked it for the parts enabling GQA:
_The key difference between Implementation A and B that enables Grouped Query Attention is having separate n_kv_heads and n_heads arguments._
_In Implementation B, n_kv_heads allows having fewer key/value projections than query projections. This enables queries to be grouped into fewer key/value heads, which is the core idea behind GQA._
_The repeat_kv function is used to repeat the keys/values to match the number of query heads. This allows compatibility with a standard multi-head attention model trained with the full n_heads._
_During inference, n_kv_heads can be set to an intermediate value between 1 (MQA) and n_heads (full MHA) to trade off quality and speed._
_In summary, the separate n_kv_heads argument and the repeat_kv function provide the flexibility needed to easily switch to GQA after training a full MHA model. The other changes in Implementation B enable training a high-quality MHA model in the first place._
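To check my own understanding of that answer, here is a rough self-contained sketch (my paraphrase, not Meta's reference code) of the shape bookkeeping it describes: wq produces n_heads heads, wk/wv produce only n_kv_heads heads, and a repeat_kv step expands K/V back before standard attention.

```python
# Minimal GQA sketch for illustration only; module/variable names are my own,
# not necessarily those of the reference implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along the head dimension."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class GQASketch(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)  # smaller than wq
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)  # smaller than wq
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, slen, _ = x.shape
        q = self.wq(x).view(bs, slen, self.n_heads, self.head_dim)
        k = self.wk(x).view(bs, slen, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(bs, slen, self.n_kv_heads, self.head_dim)
        # Expand K/V so every query head has a matching key/value head.
        n_rep = self.n_heads // self.n_kv_heads
        k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (bs, heads, slen, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(bs, slen, -1)
        return self.wo(out)


# e.g. 8 query heads grouped onto 2 key/value heads
attn = GQASketch(dim=512, n_heads=8, n_kv_heads=2)
out = attn(torch.rand(1, 16, 512))  # -> (1, 16, 512)
```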
Is this a proper assessment by Claude?
Hi,
I think this image is a good summary of GQA:
As far as I understand, GQA reduces the cache sizes for keys and values by a factor of n_heads / n_kv_heads. Because the cached keys/values have fewer heads, they later need to be expanded back to the full number of heads, which is done by simply repeating them. This repetition happens in the repeat_kv function at:
As an example:
>>> x = torch.rand(1, 2, 3, 4)
>>> x
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
[0.3441, 0.1733, 0.3397, 0.5518],
[0.2516, 0.6651, 0.1699, 0.0092]],
[[0.9057, 0.8071, 0.6634, 0.5770],
[0.1865, 0.2643, 0.8765, 0.8715],
[0.3958, 0.9162, 0.7325, 0.9555]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
[0.1269, 0.8517, 0.4630, 0.1814],
[0.3441, 0.1733, 0.3397, 0.5518],
[0.3441, 0.1733, 0.3397, 0.5518],
[0.2516, 0.6651, 0.1699, 0.0092],
[0.2516, 0.6651, 0.1699, 0.0092]],
[[0.9057, 0.8071, 0.6634, 0.5770],
[0.9057, 0.8071, 0.6634, 0.5770],
[0.1865, 0.2643, 0.8765, 0.8715],
[0.1865, 0.2643, 0.8765, 0.8715],
[0.3958, 0.9162, 0.7325, 0.9555],
[0.3958, 0.9162, 0.7325, 0.9555]]]])
The only major change that I notice is this repetition. I hope this helps you.
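To make the n_heads / n_kv_heads reduction factor concrete, here is a quick back-of-the-envelope cache-size check. The 70B-like numbers below are only illustrative assumptions, not values read from the reference config:

```python
# KV-cache size with full MHA vs. GQA; values here are assumed for illustration.
n_layers, n_heads, n_kv_heads, head_dim = 80, 64, 8, 128
max_seq_len, bs, bytes_per_el = 4096, 1, 2  # fp16/bf16

def kv_cache_bytes(kv_heads: int) -> int:
    # 2 tensors (K and V) of shape (bs, max_seq_len, kv_heads, head_dim) per layer
    return 2 * n_layers * bs * max_seq_len * kv_heads * head_dim * bytes_per_el

mha = kv_cache_bytes(n_heads)      # cache size with a KV head per query head
gqa = kv_cache_bytes(n_kv_heads)   # cache size with grouped KV heads
print(mha / 2**30, gqa / 2**30, mha / gqa)  # 10.0 GiB, 1.25 GiB, factor 8 = n_heads / n_kv_heads
```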
Thanks for the great explanation! :)
Hello Meta GenAI team (cc @ruanslv),
With regard to the 70B model, I'm currently looking into the implementation of the GQA architecture -- specifically, after noticing the 8192 x 1024 layer shapes, I was trying to identify the conditional GQA parts in your reference implementation but couldn't pin them down.
Given that there are some conditionals that smell suspiciously GQA-related, could you please elaborate on the parts of the implementation that enable this architecture specifically for the 34B / 70B models?
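For what it's worth, my own back-of-the-envelope reading of those shapes (which may well be wrong) is:

```python
# If 8192 x 1024 is a wk / wv weight shape, it would be consistent with
# dim x (n_kv_heads * head_dim); the head_dim value below is my assumption.
dim, head_dim = 8192, 128
n_kv_heads = 1024 // head_dim  # -> 8 key/value heads
n_heads = dim // head_dim      # -> 64 query heads, i.e. 8 query heads per KV head
```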
Thanks