Closed gangaraju09 closed 7 months ago
Hey! This can be easily done with a custom SlidingWindowLlamaAttention
that you register to the LLAMA_ATTENTION_CLASSES
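As a rough illustration of what such a custom attention class would compute internally, here is a minimal sketch in plain PyTorch. The function name `sliding_window_attention` is hypothetical and is not part of transformers; it ignores rotary embeddings, grouped-query heads and the KV cache, and assumes `window_size >= 1`.

```python
import math
import torch


def sliding_window_attention(q, k, v, window_size):
    """q, k, v: (batch, heads, seq_len, head_dim). Returns a tensor shaped like q."""
    seq_len = q.size(-2)
    pos = torch.arange(seq_len, device=q.device)
    # offset[i, j] = i - j: how far key j lies behind query i.
    offset = pos.unsqueeze(-1) - pos.unsqueeze(0)
    # Causal (offset >= 0) and local (offset < window_size); needs window_size >= 1.
    allowed = (offset >= 0) & (offset < window_size)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))
out = sliding_window_attention(q, k, v, window_size=3)
full = sliding_window_attention(q, k, v, window_size=6)  # window covers everything: plain causal
```

The first `window_size` query positions see the same keys in both calls, so `out` and `full` only start to differ from position 3 onwards.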
Hey @ArthurZucker,
Do you think this feature is actually useful? If we just swap the standard attention for SWA as a drop-in replacement, without any fine-tuning, won't the performance drop?
What are your thoughts?
I don't think perf should drop that much, as it would be kind of like SinkCache in a way. But we don't know until we try!
Closing as completed since you can use the LLAMA_ATTENTION_CLASSES
Hi @ArthurZucker,
Here's what I tried out (in transformers/models/llama/modeling_llama.py):
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0,
    window_size: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    # standard causal attention:
    # mask_cond = torch.arange(mask.size(-1), device=device)
    # mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    # sliding-window attention mask: row i attends to keys i - window_size + 1 .. i
    for i in range(tgt_len):
        start = max(0, i - window_size + 1)
        end = min(tgt_len, i + 1)
        mask[i, start:end] = 0
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        # note: cached positions are left fully visible here (all zeros)
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
Now this function creates a mask that slides according to the window size. Assuming the window_size is 3 and length of input is 6, the mask looks like:
[[0, -65504, -65504, -65504, -65504, -65504],
 [0, 0, -65504, -65504, -65504, -65504],
 [0, 0, 0, -65504, -65504, -65504],
 [-65504, 0, 0, 0, -65504, -65504],
 [-65504, -65504, 0, 0, 0, -65504],
 [-65504, -65504, -65504, 0, 0, 0]]
whereas the standard causal attention mask looks like this:
[[0, -65504, -65504, -65504, -65504, -65504],
 [0, 0, -65504, -65504, -65504, -65504],
 [0, 0, 0, -65504, -65504, -65504],
 [0, 0, 0, 0, -65504, -65504],
 [0, 0, 0, 0, 0, -65504],
 [0, 0, 0, 0, 0, 0]]
I am wondering if this change is enough for a plain vanilla SWA implementation.
I am not touching the position_ids, because as far as I recall they stay the same as the original ones, i.e. positions are still counted from the start of the sequence, not from the start of the window!
Please share your thoughts on this, I'd really appreciate it!
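For comparison, here is a vectorized sketch of the same mask. One difference from the loop-based function above is worth flagging: that version prepends zeros for cached positions, so every cached key stays visible once `past_key_values_length > 0`. The sketch below works in absolute positions instead, so the window also applies across the cache. The function name and signature are hypothetical, not part of transformers, and it returns only the 2-D mask without the batch and head dimensions.

```python
import torch


def sliding_window_causal_mask(tgt_len, window_size, dtype=torch.float32,
                               past_key_values_length=0, device=None):
    """Additive mask of shape (tgt_len, past + tgt_len): 0 = attend, finfo.min = masked."""
    total_len = past_key_values_length + tgt_len
    # Absolute positions of the new queries (rows) and of all keys (columns).
    q_pos = torch.arange(past_key_values_length, total_len, device=device).unsqueeze(-1)
    k_pos = torch.arange(total_len, device=device).unsqueeze(0)
    # A query may see itself and the window_size - 1 keys before it.
    allowed = (k_pos <= q_pos) & (k_pos > q_pos - window_size)
    mask = torch.full((tgt_len, total_len), torch.finfo(dtype).min,
                      dtype=dtype, device=device)
    return mask.masked_fill(allowed, 0.0)


# Same setting as the example above: window of 3, input length 6, float16.
mask = sliding_window_causal_mask(6, window_size=3, dtype=torch.float16)
print((mask == 0).int())

# One new token with 4 cached tokens: only the last 3 positions stay visible.
step = sliding_window_causal_mask(1, window_size=3, past_key_values_length=4)
```

This only addresses how the mask is built. Whether a model trained with full attention tolerates the narrower window without fine-tuning is a separate question.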
Hi @ArthurZucker,
I tried this drop-in replacement on a sample of the "BillSum" dataset, and normal attention performs way better than the drop-in SWA! The drop-in doesn't even produce fluent text, so I am not sure if this implementation is actually correct or if I am missing some details!
On the other hand, I fine-tuned Llama-7B with Guanaco using SWA for 5 epochs. Unlike the drop-in replacement, it does generate some text (still gibberish), but it is also way off compared to normal attention!
Here are a few observations:
Here's the code snippet used for generating the output:
output = model.generate(input_ids,
                        max_new_tokens=300,
                        num_beams=3,
                        do_sample=do_sample_flag,  # set to True and False to see different effects; sampling usually helps
                        top_k=100,
                        temperature=10.0,
                        no_repeat_ngram_size=5,
                        attention_mask=attn_masks)
And here's the preprocessing function
prefix = "Summarize the following bill. Focus your summary on the most important aspects of the bill. You do not have to summarize everything. Particularly focus on questions related to appropriation and the effects and impacts of the bill. However, you do not need to go into complex details, it is acceptable to provide ranges. Use active verbs to describe the bill, like 'amends' or 'changes'. Do not use ambivalent verbs like 'proposes' or 'suggests.'"
def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples['text']]
    model_inputs = tokenizer(inputs, padding=True, truncation=False,
                             max_length=1024, return_tensors='pt')
    model_inputs["labels"] = examples["summary"]
    return model_inputs
I've referred to Summarization article from here for most of the details: https://huggingface.co/docs/transformers/tasks/summarization#inference
In case you have any thoughts, feel free to let me know! TIA :)
Feature request
Hi,
I have access to the Llama-2 7B weights and am wondering how to write a wrapper that replaces the standard vanilla attention (or Grouped Attention) in Llama-2 with SWA, as explained in Longformer (https://arxiv.org/abs/2004.05150v2) and implemented in Mistral-7B (https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L60).
Ideally, it should be:
Motivation
One can load the weights using AutoModelForCausalLM and, instead of using the standard Attention block, have the model use the SWAClass. Ideally this helps with faster inference.
P.S.: Most likely, a plain drop-in replacement of vanilla attention with SWA will cost some performance! So, if there's any suggestion on how to recover the model's performance after the replacement, that would be super helpful!
Alternatively, if this is already implemented, please share the resources! I was unable to find any blogs/code-base except (https://github.com/lucidrains/local-attention)
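To put a rough number on the "faster inference" motivation, the sketch below counts how many (query, key) pairs stay visible under a full causal mask versus a sliding window. The helper name `attended_pairs` is made up for illustration. The count is only an upper bound on the saving: it assumes an attention kernel that actually skips masked positions, whereas adding a mask on top of dense attention still computes every score.

```python
def attended_pairs(seq_len, window_size=None):
    """Count (query, key) pairs that a causal mask leaves visible."""
    if window_size is None:
        # No window: plain causal attention, query i sees keys 0..i.
        window_size = seq_len
    return sum(min(i + 1, window_size) for i in range(seq_len))


# The 6-token, window-3 example from this thread.
print(attended_pairs(6), attended_pairs(6, 3))  # prints "21 15"

# A longer sequence, where the gap becomes noticeable.
print(attended_pairs(4096), attended_pairs(4096, 1024))
```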
Your contribution
I can contribute to the PR if there's some help in understanding how to proceed with this!