huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

GQA Llama 13B slower than Llama 13B without GQA #28425

Open Adonai02 opened 10 months ago

Adonai02 commented 10 months ago

Feature request

It would be nice if, when I set num_key_value_heads lower than num_attention_heads in the model's config, the key/value attention weights were automatically computed by mean pooling. Right now, doing this gives me the following error:

key_value_heads = 4

[screenshot of the error]
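For reference, the requested mean pooling can be written without Python loops. You reshape the projection weight into (KV heads, query heads per KV head, head_dim, hidden) and average over the second axis. The sketch below is minimal and assumes a bias-free Llama-style k_proj/v_proj weight of shape (num_heads * head_dim, hidden). mean_pool_kv_weight is a hypothetical helper, not a transformers API. It groups consecutive heads, which is the same grouping as the split_attention_to_heads / average_heads code in this issue.

```python
import torch


def mean_pool_kv_weight(weight: torch.Tensor, num_heads: int, num_kv_heads: int) -> torch.Tensor:
    """Hypothetical helper: mean-pool an MHA k_proj/v_proj weight of shape
    (num_heads * head_dim, hidden) into a GQA weight of shape
    (num_kv_heads * head_dim, hidden). Query heads [g*n_rep, (g+1)*n_rep)
    end up sharing KV head g."""
    out_features, hidden = weight.shape
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if out_features % num_heads != 0:
        raise ValueError("weight rows must be divisible by num_heads")
    head_dim = out_features // num_heads
    n_rep = num_heads // num_kv_heads
    # Rows are laid out head by head, so this view groups consecutive heads
    grouped = weight.reshape(num_kv_heads, n_rep, head_dim, hidden)
    # Average each group of n_rep heads in float32, then cast back
    pooled = grouped.float().mean(dim=1).to(weight.dtype)
    return pooled.reshape(num_kv_heads * head_dim, hidden)


# Toy sizes: 8 heads of dim 2, hidden size 3, pooled down to 2 KV heads
w = torch.randn(8 * 2, 3)
pooled = mean_pool_kv_weight(w, num_heads=8, num_kv_heads=2)
print(pooled.shape)  # torch.Size([4, 3])
```

For Llama-2-13B with num_key_value_heads = 4, each KV head would be the average of 10 consecutive original heads.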

Motivation

Make models such as Llama 2 13B, Llama 7B, and Mistral 7B faster.

Your contribution

I tried a simple implementation, but the results are inconsistent: the GQA model is slower than the non-GQA model.

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from copy import deepcopy
import torch

def split_attention_to_heads(input_tensor, num_splits):
    # Number of rows = num_heads * head_dim for a k_proj / v_proj weight
    rows = input_tensor.shape[0]

    # Check if the number of rows is divisible by the number of splits
    if rows % num_splits != 0:
        raise ValueError("Number of rows is not divisible by the number of splits")

    # Use chunk to split the tensor along the rows (one chunk per attention head)
    split_tensors = input_tensor.chunk(num_splits, dim=0)

    return split_tensors

def average_heads(tensor_tuple, group_size, dtype):
    # Initialize an empty list to store the averaged tensors
    averaged_tensors = []

    # Iterate through the tuple and average consecutive groups
    for i in range(0, len(tensor_tuple), group_size):
        # Take a group of tensors
        tensor_group = tensor_tuple[i:i + group_size]

        # Calculate the mean along dimension 0
        averaged_tensor = torch.mean(torch.stack(tensor_group), dim=0, dtype=dtype)

        # Append the averaged tensor to the list
        averaged_tensors.append(averaged_tensor)

    # Convert the list of averaged tensors to a tuple
    averaged_tensors_tuple = tuple(averaged_tensors)

    return averaged_tensors_tuple

def convert_wts_to_gqa(attention_module: torch.nn.Module, model_configuration: LlamaConfig):
    attentions_wts = attention_module.state_dict().copy()
    num_heads = model_configuration.num_attention_heads
    gqa_groups = num_heads // model_configuration.num_key_value_heads
    for name_wts in list(attentions_wts.keys()):
        if ("k_proj" in name_wts) or ("v_proj" in name_wts):
            tensor_to_convert = attentions_wts[name_wts].clone()
            torch_dtype = tensor_to_convert.dtype
            attn_heads = split_attention_to_heads(tensor_to_convert, num_splits=num_heads)
            gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=torch_dtype)
            gqa_tensors_grouped = torch.cat(gqa_tensors_grouped)
            attentions_wts[name_wts] = gqa_tensors_grouped
            del tensor_to_convert
    return attentions_wts

def convert_llama_to_gqa(module: torch.nn.Module, llama_config_from_hf: LlamaConfig, inplace: bool = False):
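    if isinstance(module, LlamaAttention):  # also matches LlamaSdpaAttention, a subclass
        wts_gqa = convert_wts_to_gqa(attention_module=module, model_configuration=llama_config_from_hf)
        # Note: this always builds the eager LlamaAttention, even when `module` was a
        # LlamaSdpaAttention, so the converted layers no longer use the SDPA code path.
        llama_attention_gqa = LlamaAttention(llama_config_from_hf, layer_idx=module.layer_idx)
        # The new module is created on CPU in float32; match the original weights' device and dtype
        ref_weight = module.q_proj.weight
        llama_attention_gqa.to(device=ref_weight.device, dtype=ref_weight.dtype)
        llama_attention_gqa.load_state_dict(wts_gqa)
        return llama_attention_gqa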

    out = module if inplace else deepcopy(module)
    for name, child in out.named_children():
        out._modules[name] = convert_llama_to_gqa(child, llama_config_from_hf=llama_config_from_hf, inplace=True)
    return out

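from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-13b-chat-hf"
# Original (non-GQA) model in fp16
llama = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

configuration_llama = AutoConfig.from_pretrained(model_id)
configuration_llama.num_key_value_heads = 4  # 40 query heads -> 10 query heads per KV head

llama_gqa = convert_llama_to_gqa(llama, configuration_llama)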

Results (GQA Llama)

[screenshot of GQA results]

Results (no GQA Llama)

[screenshot of non-GQA results]

I don't know whether I'm misunderstanding something. Please let me know if you can see something I'm missing.
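One detail matters when reading these timings. In the transformers Llama attention code of this period (around v4.36), a helper named repeat_kv expands the grouped key/value heads back to num_attention_heads before the attention product. The attention computation itself therefore still runs over all query heads. What GQA shrinks directly is the k_proj/v_proj matmuls and the KV cache. The sketch below reproduces that expansion in plain PyTorch and assumes torch >= 2.0 for F.scaled_dot_product_attention. repeat_kv_sketch is a stand-in written for this example, not the library function.

```python
import torch
import torch.nn.functional as F


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, head_dim) to
    (batch, num_kv_heads * n_rep, seq, head_dim), repeating each KV head
    n_rep times consecutively."""
    if n_rep == 1:
        return x
    b, kv, s, d = x.shape
    return x[:, :, None, :, :].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)


batch, num_heads, num_kv_heads, seq, head_dim = 1, 8, 2, 5, 4
q = torch.randn(batch, num_heads, seq, head_dim)
k = torch.randn(batch, num_kv_heads, seq, head_dim)
v = torch.randn(batch, num_kv_heads, seq, head_dim)

n_rep = num_heads // num_kv_heads
k_full = repeat_kv_sketch(k, n_rep)
v_full = repeat_kv_sketch(v, n_rep)
# After expansion, attention runs over all num_heads heads, just like MHA
out = F.scaled_dot_product_attention(q, k_full, v_full, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 5, 4])
```

Because of this, prefill-style benchmarks may show little or no gain from fewer KV heads. The memory savings show up mainly in the cache during generation.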

ArthurZucker commented 10 months ago

That would be nice, but it's a bit outside the scope of transformers! It would still be great to have a working example. What I recommend is registering a load_state_dict hook that converts the checkpoints on the fly. The benchmark should also run with different numbers of KV heads, since some shapes might be less optimal than others; that's my intuition. Also, a single KV head (MQA) should always be faster than MHA.
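The load_state_dict hook approach could look roughly like this. The sketch rests on several assumptions. It uses PyTorch's private Module._register_load_state_dict_pre_hook, which is long-standing but not a public API and may change. It assumes a bias-free Llama-style layout. It also uses a hypothetical ToyGQAAttention module in place of transformers' LlamaAttention. The hook mean-pools MHA-shaped k_proj/v_proj weights into the GQA shape before the parameters are copied, so an unmodified MHA checkpoint loads directly.

```python
import torch
from torch import nn


class ToyGQAAttention(nn.Module):
    """Hypothetical stand-in for a Llama-style attention block with grouped K/V heads."""

    def __init__(self, hidden: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        if hidden % num_heads or num_heads % num_kv_heads:
            raise ValueError("hidden must divide by num_heads, num_heads by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden // num_heads
        self.q_proj = nn.Linear(hidden, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        # Private PyTorch API: the hook runs before this module's and its children's
        # parameters are copied out of the state dict.
        self._register_load_state_dict_pre_hook(self._pool_mha_kv)

    def _pool_mha_kv(self, state_dict, prefix, local_metadata, strict,
                     missing_keys, unexpected_keys, error_msgs):
        n_rep = self.num_heads // self.num_kv_heads
        for name in ("k_proj.weight", "v_proj.weight"):
            key = prefix + name
            w = state_dict.get(key)
            # Convert only MHA-shaped weights (one K/V head per query head)
            if n_rep > 1 and w is not None and w.shape[0] == self.num_heads * self.head_dim:
                hidden = w.shape[1]
                grouped = w.reshape(self.num_kv_heads, n_rep, self.head_dim, hidden)
                pooled = grouped.float().mean(dim=1).to(w.dtype)
                state_dict[key] = pooled.reshape(self.num_kv_heads * self.head_dim, hidden)


torch.manual_seed(0)
hidden, num_heads, num_kv_heads = 16, 8, 2
mha_sd = {name: torch.randn(hidden, hidden) for name in ("q_proj.weight", "k_proj.weight", "v_proj.weight")}
mha_k = mha_sd["k_proj.weight"]
gqa = ToyGQAAttention(hidden, num_heads, num_kv_heads)
gqa.load_state_dict(mha_sd)  # loads despite the MHA-shaped k_proj/v_proj
print(gqa.k_proj.weight.shape)  # torch.Size([4, 16])
```

With transformers, the same kind of hook would have to be registered on each attention module of a model built with the smaller num_key_value_heads. That wiring is not shown here.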

Nexesenex commented 10 months ago

If I understand correctly, this is an attempt to retrofit GQA onto a non-GQA Llama 2 13B model? If so, and despite the slight slowdown observed, does the VRAM used by the context shrink as GQA should allow, and is the model's perplexity affected? If I've misunderstood, sorry!
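On the VRAM question: the KV cache grows linearly with the number of KV heads. Going from 40 to 4 KV heads should therefore shrink it by 10×, independent of speed. The back-of-the-envelope sketch below assumes Llama-2-13B's shape (40 layers, 40 attention heads, head_dim 128), fp16 (2 bytes per element), batch size 1, and a 4096-token context. kv_cache_bytes is a hypothetical helper. Perplexity cannot be estimated this way and would need to be measured.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Size of a full K/V cache in bytes (the leading 2 covers keys and values)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem


GIB = 1024 ** 3
mha = kv_cache_bytes(num_layers=40, num_kv_heads=40, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(num_layers=40, num_kv_heads=4, head_dim=128, seq_len=4096)
print(mha / GIB, gqa / GIB)  # 3.125 0.3125
```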