tianyabanbu opened this issue 1 year ago
Hi, thanks for the question! We always treat the memory keys as if they have position 0. Position ids inside the local context are converted to be in the range [0, 2047] here: https://github.com/CStanKonrad/long_llama/blob/0a19c56a983a49eb5a7624eb00da286cf4022a0e/src/modeling_longllama.py#L325
More context: memory layers use positional encodings for the local context in the standard way, whereas the memory keys are encoded as if they were at the beginning of the local context.
In other words, let $$t_0, t_1, t_2, t_3, \ldots, t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$ be some input. LongLLaMA will process it in context windows. First, it will process $$t_0, t_1, t_2, t_3, \ldots, t_{2047}$$ and move the (key, value) pairs from memory layers to the memory cache. The local context part ($t_0, \ldots, t_{2047}$) uses $2048$ rotary positional encodings. Then LongLLaMA will process $$t_{2048}, \ldots, t_{4095}$$ Here again the local context part ($t_{2048}, \ldots, t_{4095}$) uses the same $2048$ rotary positional encodings as the previous local context ($t_0, \ldots, t_{2047}$). Memory layers see the previous embeddings (the keys and values corresponding to $t_0, \ldots, t_{2047}$), but as if they were located at the same position as $t_{2048}$ (which is position 0 after the conversion).
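A minimal sketch of this positioning scheme (illustrative only, not the repository's exact code; `apply_rotary`, `local_position_ids`, and `rotate_memory_keys_as_position_zero` are made-up names, and the modulo-based conversion is an assumption about how positions are mapped into the local range):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                 position_ids: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq_len, head_dim]; cos/sin: [max_pos, head_dim];
    # position_ids: [batch, seq_len].
    cos = cos[position_ids].unsqueeze(1)  # -> [batch, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    return x * cos + rotate_half(x) * sin

def local_position_ids(position_ids: torch.Tensor, local_context: int = 2048) -> torch.Tensor:
    # Map absolute positions into [0, local_context - 1] so that every
    # context window reuses the same 2048 rotary encodings.
    return position_ids % local_context

def rotate_memory_keys_as_position_zero(mem_keys: torch.Tensor, cos: torch.Tensor,
                                        sin: torch.Tensor) -> torch.Tensor:
    # Memory keys are rotated with position id 0, i.e. as if they stood at
    # the beginning of the current local context.
    zero_ids = torch.zeros(mem_keys.shape[0], mem_keys.shape[-2],
                           dtype=torch.long, device=mem_keys.device)
    return apply_rotary(mem_keys, cos, sin, zero_ids)

# Toy usage: precompute cos/sin tables and rotate a window that starts at t_2048.
head_dim, max_pos = 8, 4096
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_table, sin_table = emb.cos(), emb.sin()

keys = torch.randn(1, 2, 4, head_dim)
abs_positions = torch.tensor([[2048, 2049, 2050, 2051]])
local_ids = local_position_ids(abs_positions)                    # [[0, 1, 2, 3]]
rotated_local = apply_rotary(keys, cos_table, sin_table, local_ids)
rotated_memory = rotate_memory_keys_as_position_zero(keys, cos_table, sin_table)
```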
I see, thank you very much for your answer.
I have a question about the rotary positional encoding part of the code.
Your code:
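(A reconstruction of the function being referred to, pieced together from the description below; the exact code in modeling_longllama.py may differ.)

```python
def rotate_as_if_first(x, rotary_emb):
    # x: [batch, heads, seq_len, head_dim]
    cos, sin = rotary_emb(x, x.shape[-2])
    # The position ids are all zeros, so x is rotated as if every element
    # were at the beginning of the local context.
    return rotate_one(x, cos, sin,
                      torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device))
```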
Should it be like this:
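(Likewise a sketch of the suggested alternative, based on the description below: thread the caller's `position_ids` through instead of a tensor of zeros.)

```python
def rotate_as_if_first(x, rotary_emb, position_ids):
    cos, sin = rotary_emb(x, x.shape[-2])
    # Suggested change: use the actual position_ids instead of zeros.
    return rotate_one(x, cos, sin, position_ids)
```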
When the function `rotate_as_if_first` calls the function `rotate_one`, the parameter `position_ids` needs to be passed in instead of generating a position parameter with `torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device)`.