CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.
Apache License 2.0

About the use of rotary position coding. #8

Open · tianyabanbu opened this issue 1 year ago

tianyabanbu commented 1 year ago

I have a question about the rotary positional encoding part of the code.

Your code:


```python
def rotate_as_if_first(x, rotary_emb):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device))
```

Should it be like this instead?


```python
def rotate_as_if_first(x, rotary_emb, position_ids):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, position_ids)
```

When `rotate_as_if_first` calls `rotate_one`, I think the actual `position_ids` should be passed in, instead of generating positions with `torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device)`.
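`rotate_one` itself is not quoted in the thread. For reference, here is a minimal pure-Python sketch of the standard "rotate half" RoPE formulation used by Hugging Face's LLaMA implementation. It is an assumption that `rotate_one` behaves the same way, and the helper names (`rotary_cos_sin`, `rotate_half`, `apply_rotary`) are illustrative only. The sketch shows what all-zero position ids do: at position 0 every cosine is 1 and every sine is 0, so the vector comes back unchanged, whereas a nonzero position id rotates it.

```python
import math
from typing import List

# Assumed formulation: standard rotate-half RoPE (not copied from long_llama).


def rotary_cos_sin(head_size: int, position: int, base: float = 10000.0):
    """Return cos/sin vectors of length head_size for a single position."""
    half = head_size // 2
    # One frequency per pair of dimensions, as in the usual RoPE definition.
    inv_freq = [1.0 / (base ** (2 * i / head_size)) for i in range(half)]
    angles = [position * f for f in inv_freq]
    # Duplicate the angles so both halves of the vector share the same frequencies.
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate_half(x: List[float]) -> List[float]:
    """Map (x1, x2) to (-x2, x1), where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(x: List[float], position: int) -> List[float]:
    """Rotate one head vector x as if it were located at `position`."""
    cos, sin = rotary_cos_sin(len(x), position)
    rx = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rx, sin)]


key = [0.5, -1.0, 2.0, 0.25]
# Position 0: cos = 1 and sin = 0 everywhere, so the key is unchanged.
print(apply_rotary(key, 0))  # [0.5, -1.0, 2.0, 0.25]
# A real, nonzero position id rotates the key (its norm is preserved).
print(apply_rotary(key, 5))
```

Under this formulation, rotating "as if first" is an identity rotation, so passing zeros rather than the real `position_ids` is what makes the memory keys behave as if they sat at position 0.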

CStanKonrad commented 1 year ago

Hi, thanks for the question! We always treat the memory keys as if they have position 0. Position ids inside the local context are converted to the range [0, 2047] here: https://github.com/CStanKonrad/long_llama/blob/0a19c56a983a49eb5a7624eb00da286cf4022a0e/src/modeling_longllama.py#L325
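The code at the linked line is not quoted in the thread, so the following is only a hypothetical illustration of the idea. It assumes the input is split into consecutive, non-overlapping windows of 2048 tokens that start at multiples of 2048; under that assumption, converting an absolute token index to its offset inside its own window is just `index % 2048`, which always lands in [0, 2047]. The name `local_position_ids` is made up for this sketch.

```python
LOCAL_WINDOW = 2048  # local context length mentioned in the thread


def local_position_ids(start, length, window=LOCAL_WINDOW):
    """Hypothetical helper: map absolute indices start..start+length-1 to offsets in their window.

    Assumes non-overlapping windows that begin at multiples of `window`,
    so the offset of absolute index p is simply p % window.
    """
    return [(start + i) % window for i in range(length)]


print(local_position_ids(0, 3))     # [0, 1, 2]
print(local_position_ids(2046, 4))  # [2046, 2047, 0, 1]: t_2048 maps back to position 0
print(local_position_ids(4095, 1))  # [2047]
```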

More context: memory layers apply positional encodings to the local context in the standard way, whereas the memory keys are encoded as if they were at the beginning of the local context.

In other words, let
$$t_0, t_1, t_2, t_3, \ldots, t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$
be some input. LongLLaMA will process it in context windows. First, it will process
$$t_0, t_1, t_2, t_3, \ldots, t_{2047}$$
and move the (key, value) pairs from the memory layers to the memory cache. The local context part ($t_0, \ldots, t_{2047}$) uses $2048$ rotary positional encodings. Then LongLLaMA will process
$$t_{2048}, \ldots, t_{4095}.$$
Here, again, the local context part ($t_{2048}, \ldots, t_{4095}$) uses the same $2048$ rotary positional encodings as the previous local context ($t_0, \ldots, t_{2047}$). The memory layers see the previous embeddings (the keys and values corresponding to $t_0, \ldots, t_{2047}$), but as if they were located at the same position as $t_{2048}$ (which is position 0 after the conversion).
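The windowed scheme described above can be sketched in plain Python. This is a hypothetical illustration, not LongLLaMA's code: `plan_windows`, `rope`, and `score` are made-up names, the RoPE variant is the standard rotate-half one, and windows are assumed to be non-overlapping. The final check shows the property the explanation relies on: RoPE attention scores depend only on the offset between query and key positions, so a memory key encoded at position 0 is scored (up to floating-point error) as if it sat in the first slot of the current window, i.e. the slot of $t_{2048}$ while the second window is processed.

```python
import math

WINDOW = 2048  # local context length from the thread


def plan_windows(num_tokens, window=WINDOW):
    """Hypothetical: per window, the local position ids and the positions used for memory keys."""
    plans = []
    memory_size = 0
    for start in range(0, num_tokens, window):
        end = min(start + window, num_tokens)
        local_positions = list(range(end - start))  # every window reuses positions 0, 1, ...
        memory_positions = [0] * memory_size  # every cached key is treated as position 0
        plans.append((start, end, local_positions, memory_positions))
        memory_size += end - start  # this window's keys move to the memory cache
    return plans


def rope(x, position, base=10000.0):
    """Standard rotate-half RoPE for one head vector (assumed formulation)."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        angle = position / (base ** (2 * i / len(x)))
        c, s = math.cos(angle), math.sin(angle)
        # Rotate the pair (x[i], x[i + half]) by `angle`.
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def score(q, q_pos, k, k_pos):
    """Unscaled attention logit between a rotated query and a rotated key."""
    return sum(a * b for a, b in zip(rope(q, q_pos), rope(k, k_pos)))


q = [0.3, -0.7, 1.1, 0.2]
k = [0.9, 0.4, -0.5, 1.3]
j = 17  # local position of t_{2048 + 17} inside the second window
memory_view = score(q, j, k, 0)  # memory key encoded at position 0
first_slot_view = score(q, WINDOW + j, k, WINDOW)  # key placed at t_2048's absolute slot
print(math.isclose(memory_view, first_slot_view, rel_tol=1e-9, abs_tol=1e-9))  # True
```

The design choice this reflects is that the model never sees positions beyond the 2048 it was trained on: local tokens always reuse 0 to 2047, and memory keys add no new positions at all.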

tianyabanbu commented 1 year ago

I see, thank you very much for your answer.