huggingface / transformers


Implement SWA (Sliding Window Attention) for Llama-2 7B #28915

Closed: gangaraju09 closed this issue 7 months ago

gangaraju09 commented 7 months ago

Feature request

Hi,

I have access to the Llama-2 7B weights and am wondering how to write a wrapper that replaces the standard vanilla attention (or grouped attention) in Llama-2 with SWA, as described in the Longformer paper (https://arxiv.org/abs/2004.05150v2) and implemented in Mistral-7B (https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L60).

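Ideally, it would look like this (attention_type is the proposed argument, not an existing one):

vanilla_model = AutoModelForCausalLM.from_pretrained(checkpoint)
swa_model = AutoModelForCausalLM.from_pretrained(checkpoint, attention_type='swa')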

Motivation

One could load the weights with AutoModelForCausalLM but use an SWA class (SWAClass) in place of the standard attention block. Ideally, this would make inference faster.

P.S.: Most likely, simply swapping vanilla attention for SWA as a drop-in replacement will hurt performance! So any suggestions on how to recover the model's performance after the swap would be super helpful!

Alternatively, if this is already implemented, please share the resources! I couldn't find any blogs or codebases other than https://github.com/lucidrains/local-attention.

Your contribution

I'm happy to contribute a PR if someone can help me understand how to proceed!

ArthurZucker commented 7 months ago

Hey! This can easily be done with a custom SlidingWindowLlamaAttention that you register in LLAMA_ATTENTION_CLASSES 😉

NamburiSrinath commented 7 months ago

Hey @ArthurZucker,

Do you think this feature is actually useful? If we simply swap standard attention for SWA as a drop-in replacement, without any fine-tuning, won't performance drop?

What are your thoughts?

ArthurZucker commented 7 months ago

I don't think performance should drop that much; in a way it's similar to SinkCache. But we won't know until we try! Closing as completed, since you can use LLAMA_ATTENTION_CLASSES 😉
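To make the SinkCache comparison concrete, here is a minimal pure-Python sketch of which key positions a query at position i can see under plain sliding-window attention versus a SinkCache-style scheme that also keeps the first few "sink" tokens. The helper names are hypothetical, and the sketch counts the sink tokens separately from the window, which may differ from how the real SinkCache sizes its cache (it evicts entries from the KV cache rather than masking).

```python
from typing import List


def swa_visible_keys(i: int, window_size: int) -> List[int]:
    """Key positions a query at position i may attend to under sliding-window attention."""
    return list(range(max(0, i - window_size + 1), i + 1))


def sink_visible_keys(i: int, window_size: int, num_sink_tokens: int) -> List[int]:
    """Sliding window plus the first num_sink_tokens positions (SinkCache-style)."""
    sinks = set(range(min(num_sink_tokens, i + 1)))
    recent = set(swa_visible_keys(i, window_size))
    # union, so sink positions already inside the window are not duplicated
    return sorted(sinks | recent)


if __name__ == "__main__":
    print(swa_visible_keys(10, 4))      # [7, 8, 9, 10]
    print(sink_visible_keys(10, 4, 2))  # [0, 1, 7, 8, 9, 10]
```

With pure SWA, position 0 drops out of view as soon as i >= window_size; the sink variant keeps the first tokens visible for every query.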

NamburiSrinath commented 6 months ago

Hi @ArthurZucker,

Here's what I tried (in transformers/models/llama/modeling_llama.py):

import torch

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0,
    window_size: int = 0):
    """
    Make a (sliding-window) causal mask for uni-directional self-attention.
    window_size <= 0 gives the standard causal mask.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    if window_size <= 0:
        # standard causal attention: position i may attend to positions 0..i
        window_size = tgt_len

    # sliding-window attention mask: position i may attend to positions i - window_size + 1 .. i
    for i in range(tgt_len):
        start = max(0, i - window_size + 1)
        end = min(tgt_len, i + 1)
        mask[i, start:end] = 0
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        # note: every cached (past) key is left fully visible, i.e. the window is not applied to them
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

This function creates a mask whose window slides with the query position. With window_size=3 and an input of length 6, the mask (in float16, where torch.finfo(torch.float16).min is -65504) looks like this:

[[0, -65504, -65504, -65504, -65504, -65504],
 [0, 0, -65504, -65504, -65504, -65504],
 [0, 0, 0, -65504, -65504, -65504],
 [-65504, 0, 0, 0, -65504, -65504],
 [-65504, -65504, 0, 0, 0, -65504],
 [-65504, -65504, -65504, 0, 0, 0]]

whereas the standard causal attention mask looks like this:

[[0, -65504, -65504, -65504, -65504, -65504],
 [0, 0, -65504, -65504, -65504, -65504],
 [0, 0, 0, -65504, -65504, -65504],
 [0, 0, 0, 0, -65504, -65504],
 [0, 0, 0, 0, 0, -65504],
 [0, 0, 0, 0, 0, 0]]
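The two masks above can be reproduced with a small pure-Python sketch (no PyTorch; make_swa_mask is a hypothetical helper, not a transformers function). It builds the mask from absolute positions, so when past_key_values_length > 0 the window is also applied to cached keys, unlike the all-zero block that _make_causal_mask concatenates for past keys. NEG stands in for torch.finfo(torch.float16).min.

```python
from typing import List, Optional

NEG = -65504.0  # torch.finfo(torch.float16).min


def make_swa_mask(tgt_len: int, window_size: Optional[int] = None,
                  past_key_values_length: int = 0) -> List[List[float]]:
    """Additive mask of shape (tgt_len, past_key_values_length + tgt_len).

    0 means "may attend", NEG means "masked".
    window_size=None gives the standard causal mask.
    """
    src_len = past_key_values_length + tgt_len
    mask = []
    for row in range(tgt_len):
        q = past_key_values_length + row  # absolute position of this query
        lo = 0 if window_size is None else q - window_size + 1
        mask.append([0.0 if lo <= k <= q else NEG for k in range(src_len)])
    return mask


if __name__ == "__main__":
    for r in make_swa_mask(6, window_size=3):
        print(r)
```

In PyTorch, the same condition can be expressed without the Python loop by broadcasting two arange tensors and keeping positions where k <= q and k > q - window_size.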

I'm wondering whether this change is enough for a plain, vanilla SWA implementation. I'm not touching position_ids because, as I recall, they stay the same as in the original model, i.e. positions are counted from the start of the sequence, not from the start of the window!
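Keeping the original position_ids is consistent with how Llama encodes positions: it uses rotary position embeddings (RoPE), under which the query-key score depends only on the offset between the two positions, not on their absolute values. The pure-Python sketch below checks that property for a single 2-D rotary pair; rotate, rope_score, and the frequency theta=0.1 are illustrative assumptions, not Llama's actual per-dimension frequencies.

```python
import math
from typing import Tuple

Vec = Tuple[float, float]


def rotate(v: Vec, pos: int, theta: float = 0.1) -> Vec:
    """Rotate a 2-D vector by pos * theta, as RoPE does for one pair of dimensions."""
    a = pos * theta
    return (v[0] * math.cos(a) - v[1] * math.sin(a),
            v[0] * math.sin(a) + v[1] * math.cos(a))


def rope_score(q: Vec, k: Vec, q_pos: int, k_pos: int) -> float:
    """Dot product of the rotated query and rotated key."""
    rq, rk = rotate(q, q_pos), rotate(k, k_pos)
    return rq[0] * rk[0] + rq[1] * rk[1]


if __name__ == "__main__":
    q, k = (1.0, 2.0), (0.5, -1.0)
    # same offset (2) at very different absolute positions -> same score
    print(rope_score(q, k, 5, 3), rope_score(q, k, 105, 103))
```

So under SWA, a query only ever sees keys at offsets 0 to window_size - 1, and those offsets are scored exactly as they would be with absolute positions unchanged.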

Please share your thoughts on this; I'd really appreciate it!

NamburiSrinath commented 6 months ago

Hi @ArthurZucker,

I tried this drop-in replacement on a sample of the BillSum dataset, and normal attention performs far better than drop-in SWA! The drop-in version doesn't even produce fluent text, so I'm not sure whether this implementation is actually correct or whether I'm missing some details!

I also fine-tuned Llama-7B with SWA on Guanaco for 5 epochs. Unlike the drop-in replacement, it does generate some text (though still gibberish), but it is also far behind normal attention!

Here are a few observations:

  1. Normal attention is faster than the SWA variants (measured with time.time()), even though theory says SWA should be faster for long texts (and BillSum texts are long!).
  2. For long-text summarization, all three variants (vanilla attention, drop-in SWA, and fine-tuned SWA) mostly produce gibberish, but for some reason the SWA variants suffer much more (gibberish or no text at all) than the vanilla model!
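One likely reason for observation 1: masking only changes which scores are kept, not how many are computed. A mask-based SWA still builds the full tgt_len x tgt_len score matrix (plus a Python loop to build the mask in _make_causal_mask), so it should not be expected to beat vanilla attention; the savings only appear when the kernel computes scores inside the window alone. The pure-Python sketch below (single head, tiny made-up inputs; all names are hypothetical) compares the two approaches: both give the same output, but the banded version computes far fewer scores.

```python
import math
from typing import List, Tuple

Vector = List[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def _softmax_weighted_sum(scores: List[float], values: List[Vector]) -> Vector:
    """softmax(scores) @ values; masked scores are -inf and get weight 0."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def dense_masked_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                           window_size: int) -> Tuple[List[Vector], int]:
    """Score every (query, key) pair, then mask keys outside the causal window."""
    out, n_scores = [], 0
    for i in range(len(q)):
        scores = []
        for j in range(len(k)):
            n_scores += 1
            s = _dot(q[i], k[j]) / math.sqrt(len(q[i]))
            scores.append(s if i - window_size < j <= i else -math.inf)
        out.append(_softmax_weighted_sum(scores, v))
    return out, n_scores


def banded_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                     window_size: int) -> Tuple[List[Vector], int]:
    """Score only the keys inside the causal window."""
    out, n_scores = [], 0
    for i in range(len(q)):
        start = max(0, i - window_size + 1)
        scores = [_dot(q[i], k[j]) / math.sqrt(len(q[i])) for j in range(start, i + 1)]
        n_scores += len(scores)
        out.append(_softmax_weighted_sum(scores, v[start:i + 1]))
    return out, n_scores


if __name__ == "__main__":
    T, dim, w = 8, 4, 3
    q = [[math.sin(i + d) for d in range(dim)] for i in range(T)]
    k = [[math.cos(i * d) for d in range(dim)] for i in range(T)]
    v = [[float(i - d) for d in range(dim)] for i in range(T)]
    dense, n_dense = dense_masked_attention(q, k, v, w)
    banded, n_banded = banded_attention(q, k, v, w)
    print(n_dense, n_banded)  # 64 21
```

Real speedups need a kernel that skips the out-of-window blocks entirely; recent FlashAttention-2 releases, for example, accept a window_size argument for this.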

Here's the code snippet I used to generate the output:

      output = model.generate(input_ids,
                              max_new_tokens=300,
                              num_beams=3,
                              do_sample=do_sample_flag,  # set to True and False to compare; sampling usually helps
                              top_k=100,                 # only used when do_sample=True
                              temperature=10.0,          # only used when do_sample=True; values > 1 flatten the distribution
                              no_repeat_ngram_size=5,
                              attention_mask=attn_masks)
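The sampling settings may matter as much as the attention change here. With do_sample=True, the logits are divided by temperature before the softmax, so temperature=10.0 flattens the distribution strongly, and low-probability tokens among the top_k=100 candidates get picked far more often. A pure-Python sketch with three made-up logits shows the effect (softmax_with_temperature is a hypothetical helper, not the transformers implementation):

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax of logits / temperature (shifted by the max for numerical stability)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits = [5.0, 2.0, 0.0]  # made-up logits for three candidate tokens
    print([round(p, 3) for p in softmax_with_temperature(logits, 1.0)])   # [0.946, 0.047, 0.006]
    print([round(p, 3) for p in softmax_with_temperature(logits, 10.0)])  # [0.426, 0.316, 0.258]
```

With do_sample=False and num_beams=3, the call runs plain beam search and temperature/top_k have no effect, so comparing both settings separates sampling noise from the attention change.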

And here's the preprocessing function:

prefix = "Summarize the following bill. Focus your summary on the most important aspects of the bill. You do not have to summarize everything. Particularly focus on questions related to appropriation and the effects and impacts of the bill. However, you do not need to go into complex details, it is acceptable to provide ranges. Use active verbs to describe the bill, like 'amends' or 'changes'. Do not use ambivalent verbs like 'proposes' or 'suggests.'"

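def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples['text']]
    # with truncation=False, max_length is not applied, so long bills are not cut to 1024 tokens
    model_inputs = tokenizer(inputs, padding=True, truncation=False,
                             max_length=1024, return_tensors='pt')
    # reference summaries are kept as raw strings, not token ids
    model_inputs["labels"] = examples["summary"]
    return model_inputs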

I followed the summarization guide here for most of the details: https://huggingface.co/docs/transformers/tasks/summarization#inference

In case you have any thoughts, feel free to let me know! Thanks in advance :)