meta-llama / llama

Inference code for Llama models

Grouped-Query Attention #384

Open 19h opened 1 year ago

19h commented 1 year ago

Hello Meta GenAI team (cc @ruanslv),

Regarding the 70B model, I'm currently looking into the implementation of the GQA architecture. Specifically, after noticing the 8192 x 1024 layer shapes, I tried to identify the conditional GQA parts in your reference implementation but couldn't pin them down.

Given that there are some conditions that smell suspiciously GQA-related, could you please elaborate on the parts of the implementation that enable this architecture specifically for the 34B / 70B models?
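For reference, here's the shape arithmetic I'm going by (a rough sketch; dim=8192 and n_heads=64 match the published 70B config, while n_kv_heads=8 is my guess from the 1024 dimension):

# Back-of-the-envelope projection shapes; n_kv_heads is assumed, not confirmed.
dim = 8192                   # 70B model dimension
n_heads = 64                 # query heads
head_dim = dim // n_heads    # 128
n_kv_heads = 8               # assumed number of shared KV heads

wq_shape = (dim, n_heads * head_dim)     # (8192, 8192)
wk_shape = (dim, n_kv_heads * head_dim)  # (8192, 1024) -- the shape I observed
wv_shape = (dim, n_kv_heads * head_dim)  # (8192, 1024)
print(wq_shape, wk_shape, wv_shape)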

Thanks

19h commented 1 year ago

Out of impatience, I asked Claude 2 about the differences between Implementation A (LLaMA 1) and Implementation B (LLaMA 2):

[screenshot of Claude's comparison]

Then I explained the concept of GQA and asked it for the parts enabling GQA:

[screenshot of Claude's answer]

Is this a proper assessment by Claude?

byildiz commented 1 year ago

Hi,

I think this image is a good summary of GQA:

[image: diagram comparing multi-head, grouped-query, and multi-query attention]

As far as I understand, GQA reduces the key and value cache sizes by a factor of n_heads / n_kv_heads. Because the cached keys and values have fewer heads, they need to be expanded back to the full number of heads before the attention computation. This is done by simply repeating them, which happens in the repeat_kv function at:

https://github.com/facebookresearch/llama/blob/4d92db8a1db6c7f663252bf3477d2c4b8bad2385/llama/model.py#L77
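If it helps, the function there is essentially the following (my paraphrase of the linked code, not a verbatim copy):

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=2)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]                               # add a repeat axis
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)    # broadcast without copying
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # flatten back to n_heads
    )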

As an example:

>>> import torch
>>> x = torch.rand(1, 2, 3, 4)
>>> x
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
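As far as I can tell, the whole expand/reshape dance is equivalent to a single torch.repeat_interleave call, which we can check in the same session:

>>> expanded = x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
>>> torch.equal(torch.repeat_interleave(x, repeats=n_rep, dim=2), expanded)
True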

The only major change I notice is this repetition. I hope this helps.
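To put rough numbers on the cache saving, a back-of-the-envelope sketch (my own numbers; n_heads=64 is from the 70B config, n_kv_heads=8 is assumed, fp16, batch size 1):

# Hypothetical per-layer KV cache comparison; values assumed, not from the repo.
seq_len, head_dim = 4096, 128
n_heads, n_kv_heads = 64, 8        # n_kv_heads=8 is my assumption for 70B
bytes_per_el = 2                   # fp16

mha_cache = 2 * seq_len * n_heads * head_dim * bytes_per_el     # keys + values
gqa_cache = 2 * seq_len * n_kv_heads * head_dim * bytes_per_el
print(mha_cache / 2**20, "MiB vs", gqa_cache / 2**20, "MiB")    # 128.0 MiB vs 16.0 MiB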

missflash commented 1 year ago

Thanks for the great explanation! :)