CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.
Apache License 2.0
1.45k stars, 85 forks

FoT attention and the scaling trick #4

Open StrangeTcy opened 1 year ago

StrangeTcy commented 1 year ago

In your paper, you say:

> Position Interpolation (PI, [Chen et al., 2023] and [kaiokendev, 2023]) introduces a modification to the rotary positional encoding scheme that enables fine-tuning for 32K context. In contrast to this work, our method does not rely on positional encodings, following the findings from [Haviv et al., 2022]. Removing positional encoding in memory allows us to extrapolate to 256k tokens, although the model was only trained on sequences up to 8K, yielding theoretically unbounded context length.
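For context on the "scaling trick": Position Interpolation rescales the RoPE position indices by `train_len / target_len`, so a longer sequence is squeezed into the position range the model saw during pre-training. The NumPy sketch below assumes a 2048-token pre-training window and a 32K target (illustrative values); `rope_angles` and `interpolated_positions` are hypothetical helper names, not code from either project.

```python
import numpy as np

def rope_angles(positions, dim, base=10000.0):
    # RoPE rotation angle for every (position, frequency) pair; shape (n, dim // 2).
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    return np.outer(positions, inv_freq)

def interpolated_positions(seq_len, train_len, target_len):
    # Position Interpolation: multiply each index by train_len / target_len so
    # positions 0 .. target_len - 1 fall inside the pre-training range [0, train_len).
    return np.arange(seq_len) * (train_len / target_len)

# Illustrative values: 2048-token pre-training window, 32K target context.
pos = interpolated_positions(32768, train_len=2048, target_len=32768)
angles = rope_angles(pos, dim=128)  # fractional positions feed straight into RoPE
```

With these values, position 16 receives the angles that position 1 had during pre-training, and the last position maps to 2047.9375.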

Does that mean that one can't use both scaled positional embeddings and FoT attention?

soacker commented 1 year ago

I think it's because FoT attention, as applied here, does not use scaled positional embeddings; instead it sums in the additional (memory) parts.

CStanKonrad commented 1 year ago

Hi, thanks for the question. In short, we have not tried combining scaled positional encodings with FoT attention, so we cannot comment on performance.

Originally, FoT was designed to allow the model to handle large databases consisting of millions of keys and values from multiple unrelated documents. In such a setup, it is not clear how to apply positional encodings. This is reflected in our experiments with smaller models, where we disable positional encodings in memory layers (the other layers keep positional encodings). LongLLaMA models differ slightly. Namely, all layers except the memory layers use positional encodings in the standard way. Memory layers use positional encodings in the standard way for the local context, whereas the memory keys are encoded as if they were at the beginning of the local context.
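The per-layer position assignment described above can be summarized in a few lines. This is a minimal sketch, assuming positions are numbered from 0 within the current local window; `key_positions` and the `variant` names are hypothetical. Reading "disable positional encodings in memory layers" as applying to the whole memory layer in the smaller models is an interpretation.

```python
from typing import List, Optional

def key_positions(is_memory_layer: bool, num_memory: int, local_len: int,
                  variant: str = "longllama") -> List[Optional[int]]:
    """Rotary position assigned to each key a layer attends to (hypothetical helper).

    Memory keys come first, local keys after. None means "no positional encoding".
    """
    local = list(range(local_len))  # standard positions for the local window
    if not is_memory_layer:
        return local  # non-memory layers never see the memory cache
    if variant == "longllama":
        # Memory keys are encoded as if they sat at the start of the local window.
        return [0] * num_memory + local
    if variant == "no_pe":
        # Smaller FoT models (as interpreted here): no positional encoding in memory layers.
        return [None] * (num_memory + local_len)
    raise ValueError(f"unknown variant: {variant}")
```

For a memory layer with three cached keys and a four-token window, the LongLLaMA variant yields `[0, 0, 0, 0, 1, 2, 3]`: the three memory keys all share position 0 with the first local token.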

In other words, let $$t_0, t_1, t_2, t_3, \ldots, t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$ be some input. LongLLaMA will process it in context windows. First, it will process $$t_0, t_1, t_2, t_3, \ldots, t_{2047}$$ and move the (key, value) pairs from the memory layers to the memory cache. Then it will process $$t_{2048}, \ldots, t_{4095}.$$ In this step, non-memory layers process only 2048 embeddings, whereas memory layers also see the previous embeddings (keys and values), but as if they were located at the same position as $t_{2048}$.
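The windowed processing above can be sketched for a single attention head of one memory layer. This is a simplified NumPy illustration, not the LongLLaMA implementation: it skips the projections (taking `q`, `k`, `v` directly), attends to the whole cache instead of any retrieval step, and uses the interleaved-pair RoPE convention. `rotate`, `memory_layer_attention` and `process_in_windows` are hypothetical names. Cached keys are stored unrotated and rotated at local position 0, i.e. the position of $t_{2048}$ in the second window.

```python
import numpy as np

def rotate(x, positions, base=10000.0):
    # RoPE on x of shape (n, d), d even, rotating pairs (2i, 2i + 1).
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))
    ang = np.outer(positions, inv_freq)  # (n, d // 2)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

def memory_layer_attention(q, k, v, mem_k, mem_v):
    # q, k, v: (L, d) for the current window; mem_k, mem_v: (M, d) cached, unrotated.
    L, d = q.shape
    M = len(mem_k)
    local_pos = np.arange(L)
    q_r = rotate(q, local_pos)
    k_r = rotate(k, local_pos)
    m_r = rotate(mem_k, np.zeros(M))  # every memory key sits at local position 0
    keys = np.concatenate([m_r, k_r])
    vals = np.concatenate([mem_v, v])
    scores = q_r @ keys.T / np.sqrt(d)
    # Causal mask on the local part only; all memory entries are visible.
    mask = np.zeros((L, M + L), dtype=bool)
    mask[:, M:] = np.triu(np.ones((L, L), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ vals

def process_in_windows(q, k, v, window=2048):
    d = q.shape[1]
    mem_k, mem_v = np.zeros((0, d)), np.zeros((0, d))
    outputs = []
    for start in range(0, len(q), window):
        sl = slice(start, start + window)
        outputs.append(memory_layer_attention(q[sl], k[sl], v[sl], mem_k, mem_v))
        # After the window is processed, its (key, value) pairs move to the memory cache.
        mem_k = np.concatenate([mem_k, k[sl]])
        mem_v = np.concatenate([mem_v, v[sl]])
    return np.concatenate(outputs)
```

Because RoPE attention scores depend only on the offset between query and key positions, numbering the local window from 0 or from 2048 gives the same scores; what matters is that every memory key shares the position of the window's first token.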

We do this in order to maintain compatibility with the LLaMA code.

StrangeTcy commented 1 year ago

I figured as much after a re-reading of the respective parts of the paper, but the whole "they encode them as if they were at the beginning of the local context" wasn't very clear to me until your explanation, so thanks for that.