johnsmith0031 / alpaca_lora_4bit


Implementing Landmark Attention #116

Open juanps90 opened 1 year ago

juanps90 commented 1 year ago

Landmark Attention (from the paper's abstract):

While transformers have shown remarkable success in natural language processing, their attention mechanism's large memory requirements have limited their ability to handle longer contexts. Prior approaches, such as recurrent memory or retrieval-based augmentation, have either compromised the random-access flexibility of attention (i.e., the capability to select any token in the entire context) or relied on separate mechanisms for relevant context retrieval, which may not be compatible with the model's attention. In this paper, we present a novel approach that allows access to the complete context while retaining random-access flexibility, closely resembling running attention on the entire context. Our method uses a landmark token to represent each block of the input and trains the attention to use it for selecting relevant blocks, enabling retrieval of blocks directly through the attention mechanism instead of by relying on a separate mechanism. Our approach seamlessly integrates with specialized data structures and the system's memory hierarchy, enabling processing of arbitrarily long context lengths. We demonstrate that our method can obtain comparable performance with Transformer-XL while significantly reducing the number of retrieved tokens in each step. Finally, we show that fine-tuning LLaMA 7B with our method successfully extends its context length capacity up to 32k tokens, allowing for inference at the context lengths of GPT-4.
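The core mechanism the abstract describes can be illustrated with a small sketch. The following is a minimal, conceptual rendition of the block-retrieval idea only, not the paper's actual implementation: the function name `landmark_block_attention`, the tensor layout, and the `block_size`/`top_k` parameters are assumptions made purely for illustration. It scores each block of the KV cache via its landmark key, keeps the top-k blocks, and runs ordinary attention over just the retrieved tokens.

```python
import torch
import torch.nn.functional as F

def landmark_block_attention(q, k, v, landmark_k, block_size=64, top_k=4):
    """Toy single-query sketch of landmark-style block retrieval.

    q:          (heads, 1, dim)         current query
    k, v:       (heads, seq_len, dim)   cached keys/values, grouped into blocks
    landmark_k: (heads, n_blocks, dim)  one landmark key per block (assumed layout)
    """
    heads, seq_len, dim = k.shape
    assert seq_len % block_size == 0
    n_blocks = seq_len // block_size

    # Score every block by attending to its landmark key, then keep the top-k blocks.
    block_scores = (q @ landmark_k.transpose(-1, -2)) / dim**0.5        # (heads, 1, n_blocks)
    top_blocks = block_scores.topk(min(top_k, n_blocks), dim=-1).indices.squeeze(1)

    # Gather the keys/values belonging to the retrieved blocks only.
    k_blocks = k.view(heads, n_blocks, block_size, dim)
    v_blocks = v.view(heads, n_blocks, block_size, dim)
    idx = top_blocks[:, :, None, None].expand(-1, -1, block_size, dim)  # (heads, top_k, block, dim)
    k_sel = torch.gather(k_blocks, 1, idx).reshape(heads, -1, dim)
    v_sel = torch.gather(v_blocks, 1, idx).reshape(heads, -1, dim)

    # Standard attention, restricted to the retrieved tokens.
    attn = F.softmax((q @ k_sel.transpose(-1, -2)) / dim**0.5, dim=-1)
    return attn @ v_sel                                                  # (heads, 1, dim)
```

Because only the selected blocks need to be resident for the attention step, the rest of the cache can sit further down the memory hierarchy, which is what the abstract refers to when it says the method enables arbitrarily long context lengths.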

The first issue on their GitHub page has the following comment:

Yes, currently it relies on standard finetuning to transform an existing model to longer context. Conceptually, it should be perfectly reasonable to try LoRA finetuning instead for more memory efficiency, or even QLoRA.

We're playing with LoRA at the moment. One change you'll have to make is to also unfreeze the initial embedding layer, to allow it to learn the new landmark tokens. Otherwise, it should work out of the box.

Let us know if you manage to get it running by combining the two codebases. We can keep the issue open if other people share their experience as well.
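For the embedding-unfreezing step mentioned in that comment, here is a hedged sketch of how the setup might look on the standard Hugging Face transformers + peft stack (not this repo's 4-bit loader): the checkpoint path, the `<landmark>` token string, and the exact module names are placeholders/assumptions, not values taken from either codebase.

```python
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import LoraConfig, get_peft_model

base = "path/to/llama-7b"  # placeholder checkpoint path
tokenizer = LlamaTokenizer.from_pretrained(base)
model = LlamaForCausalLM.from_pretrained(base)

# Register the landmark token and grow the embedding matrix to the new vocab size.
tokenizer.add_special_tokens({"additional_special_tokens": ["<landmark>"]})
model.resize_token_embeddings(len(tokenizer))

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
    # Keep the input embeddings fully trainable (and saved with the adapter)
    # so the new landmark token's embedding can actually be learned.
    modules_to_save=["embed_tokens"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

With the 4-bit quantized path used by this repo, the same idea would presumably apply, but the embedding layer would need to stay in a trainable, non-quantized dtype.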

This repo has been my go-to for finetuning purposes and I'm super thankful. Being able to exceed the 2048-token mark would be perfect.