The true context window of transformers is effectively baked into the model. It's an inherent property, tied to how the model learns positional embeddings and all of its layer states. If you want a larger context you must completely retrain the model. Context is typically limited because of the computational complexity of the type of attention used by these models. However, there are some variants, like Longformer, which support 16k context.
There used to be recurrent networks (RNN, LSTM) with effectively unbounded context, but they are too hard to train using backpropagation, so after "Attention Is All You Need" everybody switched to transformer models.
EDIT: Apparently RNNs are not dead yet: https://github.com/ggerganov/llama.cpp/issues/846
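On the "baked in" point: here is a minimal NumPy sketch (not code from this repo) of the classic sinusoidal position encodings from "Attention Is All You Need". LLaMA actually uses rotary embeddings (RoPE) rather than these, but the intuition is the same: the weights are only ever trained against the position patterns for 0..n_ctx-1, so longer positions are out of distribution.

```python
import numpy as np

def sinusoidal_positions(n_positions: int, d_model: int) -> np.ndarray:
    """Fixed position encodings from "Attention Is All You Need".

    Each position gets a unique pattern of sines/cosines; the model's weights
    are only ever fit to the patterns for positions 0..n_ctx-1.
    """
    pos = np.arange(n_positions)[:, None]         # (n_positions, 1)
    dim = np.arange(0, d_model, 2)[None, :]       # (1, d_model/2)
    angles = pos / (10000 ** (dim / d_model))     # (n_positions, d_model/2)
    enc = np.zeros((n_positions, d_model))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

# Positions beyond the training window produce patterns the weights were
# never trained on, which is one way the limit gets "baked in".
print(sinusoidal_positions(2048, 64).shape)  # (2048, 64)
```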
So the abstract from the Longformer paper answers some questions about computational complexity:
Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length
They then say that mere linear scaling "makes it easy"!
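To make the scaling claim concrete, here is a toy single-head sketch of sliding-window attention in the Longformer spirit (plain NumPy, no global tokens, no batching): each token attends only to a fixed local window, so the work is O(n·w) rather than O(n²).

```python
import numpy as np

def sliding_window_attention(q, k, v, window: int):
    """Toy local attention: token i attends only to keys in [i-window, i].

    q, k, v: (seq_len, d) arrays. Work per token is O(window), so the whole
    pass is O(seq_len * window) rather than O(seq_len ** 2).
    """
    seq_len, d = q.shape
    out = np.zeros_like(v)
    for i in range(seq_len):
        lo = max(0, i - window)
        scores = q[i] @ k[lo:i + 1].T / np.sqrt(d)   # at most window+1 scores
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[i] = weights @ v[lo:i + 1]
    return out

q = k = v = np.random.randn(16, 8)
print(sliding_window_attention(q, k, v, window=4).shape)  # (16, 8)
```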
Has a Longformer inspired mechanism made it into any of the huge models?
It still ultimately depends on a discrete set of positions for the text to sit in, right? Modelled like pigeonholes / slots.
Maybe there's a way to project it continuously.
In theory, you could approximate it by using the model to compress the existing context window. For example, let's say you have a maximum window of 2048. Once you hit 2048 tokens of context, take the first part of the window, summarize it with the model, and replace it with that summary.
You can use a layer on top of llama.cpp that either summarizes older context or keeps only a rolling window of the most recent messages.
The memory and compute required for larger context windows scale quadratically with the window size (in the standard attention implementation), so a 100x lossless context window is impractical from what I understand.
A PR that implements what ChatGPT does would be neat (https://www.pinecone.io/learn/langchain-conversational-memory/), but that isn't much of an increase. It also has the issue that summarizing usually loses important information; it's at the whim of what the language model finds important.
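For illustration, a rough sketch of such a rolling-summary layer, assuming hypothetical `generate(prompt)` and `count_tokens(text)` helpers that wrap llama.cpp; the 50/50 split and the summarization prompt are arbitrary choices, not anything ChatGPT or LangChain specifies.

```python
# Hypothetical rolling-summary memory, in the spirit of LangChain's
# ConversationSummaryMemory. `generate` and `count_tokens` are assumed
# wrappers around llama.cpp, not real APIs.

MAX_TOKENS = 2048

def compress_context(history: str, generate, count_tokens) -> str:
    """When the history exceeds the window, summarize its older half."""
    if count_tokens(history) <= MAX_TOKENS:
        return history
    cut = len(history) // 2                  # crude split by characters
    older, recent = history[:cut], history[cut:]
    summary = generate(
        "Summarize the following conversation, keeping names, facts and "
        "decisions:\n\n" + older
    )
    # The summary replaces the older half; detail is inevitably lost and
    # the model decides what counts as important.
    return "Summary of earlier conversation: " + summary + "\n" + recent
```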
I am curious whether it would be possible to bypass the low compression of written language by adding custom embedding tokens that represent phrases/concepts instead of letters; they would need to be tuned/trained in, and I guess set to zero probability when sampling.
Edit: I guess that idea would be a multi-modal model.
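If someone wanted to try that, the mechanical part (registering new tokens and growing the embedding matrix) is simple with Hugging Face transformers; the model id and the concept tokens below are placeholders, and the hard part, actually training those embeddings to stand for phrases, is not shown.

```python
# Sketch of adding "concept" tokens with Hugging Face transformers.
# Which phrases deserve their own tokens is the open question; these
# two are placeholders, as is the model id.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

num_added = tokenizer.add_tokens(["<concept:quantum_physics>", "<concept:lotr>"])
model.resize_token_embeddings(len(tokenizer))  # new rows start untrained

# The new embeddings only become meaningful after finetuning on data that
# uses them; at inference you would ban them from being sampled.
```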
Finetuning / LoRA with a longer max sequence length?
The max context length is not something that can be extended because there are no weights for it. I think it may be possible to mess with the KV cache to compress it down or something.
The context length isn't associated with weights, is it? The position encodings are summed into each token embedding; the existing weights would simply need to be finetuned for the new position encodings, no?
I woke up thinking about interpolating input vectors, and/or the early representations that they make - whatever is created when the attention matrices are applied to the input embedding.
For example, could you blend "Explain quantum physics in easy to understand terms", "Gandalf", "Cyberpunk" and "Lord of the Rings" in an associative / multiplicative way that uses many more bits, and bits not tied directly to positional encoding, to move around in the latent / representational space in a new, interesting way?
Like how visual AIs can blend between arbitrary points in their representation space.
In some way, maybe the compression of particularly large contexts that we do in our own brains is more akin to that than to sentence stringification.
For whatever it's worth.
I was thinking that maybe the K and V buffers can be manipulated using some image processing techniques like maxpool or convolutions.
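As a strawman, pooling a K or V buffer along the token axis might look like the sketch below (NumPy, with the buffer assumed to be laid out as (n_tokens, n_embd) per layer); whether the model can still make sense of the pooled cache is exactly the open question.

```python
import numpy as np

def maxpool_kv(cache: np.ndarray, stride: int = 2) -> np.ndarray:
    """Downsample a K or V buffer along the token axis with max pooling.

    cache: (n_tokens, n_embd). Returns (ceil(n_tokens / stride), n_embd).
    """
    n_tokens, n_embd = cache.shape
    pad = (-n_tokens) % stride
    padded = np.pad(cache, ((0, pad), (0, 0)), constant_values=-np.inf)
    return padded.reshape(-1, stride, n_embd).max(axis=1)

k = np.random.randn(2048, 4096)
print(maxpool_kv(k, stride=2).shape)  # (1024, 4096)
```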
If I were innovating something new here, my present idea is to organize an architecture similar to the concepts used in LangChain, AutoGPT, and llama-index, but in-training: that is, I might wire the layers so they can navigate and update knowledge graphs of logits and weights, with concepts of nearness and relevance, and finetune them to learn to do this.
The context length isn't associated with weights, is it? The position encodings are summed into each token embedding; the existing weights would simply need to be finetuned for the new position encodings, no?
I was wrong about this; I thought there were tensors with size 2048 in them, but there aren't. The KQV weights are applied to the layer input size (so 4096). The self-attention happens over all the past context: the huge matrices are multiplied, but no matter the size of the context, the output is still the embedding size that is passed to the next layer or becomes the result.
So it should be possible to increase the limits without any issues? Yes. You need to change the hard-coded limit in llama.cpp and also increase the memory to be allocated. Everything runs; there are no crashes.
Except it doesn't work: once you go past some number of tokens, everything breaks down and the output becomes more and more garbage. I don't know what the reason is. It could be something technical, like NaNs or Infs appearing, or maybe it is never going to work because at some point we are "taking the average" of too many tokens and there is nothing meaningful to extract from it.
the huge matrices are multiplied, but no matter the size of the context, the output is still the embedding size that is passed to the next layer or becomes the result.
That makes it sound enticing to try averaging or otherwise interpolating those matrices given by certain inputs.
Edit: Which I see you may have promptly done 🤓
Except it doesn't work: once you go past some number of tokens, everything breaks down and the output becomes more and more garbage. I don't know what the reason is. It could be something technical, like NaNs or Infs appearing, or maybe it is never going to work because at some point we are "taking the average" of too many tokens and there is nothing meaningful to extract from it.
That is one possible outcome. It could be that it doesn't make sense to move around in that space in any obvious ways.
Maybe it does have position information embedded in it, and this deranges it.
But it's a neat place to start.
Could you model position with something like a signed distance function?
Or maybe some kind of modular arithmetic. So you can fractally divide relative distances in self-contained modular units.
I don't know how well founded those thoughts are. I'm an amateur / ignorant of the brass tacks as mentioned. I'm trying to flick sparks at smart people from enthusiast intuition.
I was thinking that maybe the K and V buffers can be manipulated using some image processing techniques like maxpool or convolutions.
Sounds like there are many more experiments you could try 👍
I plan to.
There are some interesting things to see in the KV cache; for example, the K data seems to change little from token to token.
It is not easy to analyze right now: you have to modify the code to dump the data, then you can load it with NumPy and visualize it with matplotlib.
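For anyone wanting to repeat that, the analysis side is only a few lines; this sketch assumes the dumped K buffer was saved as a `.npy` array of shape (n_tokens, n_embd), and the filename is made up.

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumed dump: one .npy file per layer, shape (n_tokens, n_embd).
k = np.load("k_cache_layer0.npy")

# How much does K change from one token to the next?
dk = np.linalg.norm(np.diff(k, axis=0), axis=1)

plt.plot(dk)
plt.xlabel("token index")
plt.ylabel("|K[t+1] - K[t]|")
plt.title("Per-token change in the K cache (layer 0)")
plt.show()
```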
You almost want a dedicated tool to visualise the various possible generation buckets, to get a sense of the range with a certain function applied, whether it's temperature or something more adventurous ^.
Except it doesn't work: once you go past some number of tokens, everything breaks down and the output becomes more and more garbage. I don't know what the reason is. It could be something technical, like NaNs or Infs appearing, or maybe it is never going to work because at some point we are "taking the average" of too many tokens and there is nothing meaningful to extract from it.
Yes, I was meaning to finetune on longer data so it would learn that the new position encodings exist at all.
It's notable that, coincidentally, models finetuned to longer context lengths have now been published.
I was thinking about this a little more this morning, and about the similarities between training and inference, and I'm wondering if you could fold a prefixed prompt into some of the model weights, such that the model functions as if it always had this prompt, without needing to allocate RAM to process it. Assuming you can do that, you could then train models in an RNN way, with infinite context, by folding and unfolding weights.
Anything I’m missing with that idea?
I’m wondering if you could fold a prefixed prompt into some of the model weights
We saw something like this with #1472, where another prompt could influence the input, and it only needed processing on one layer.
RNN way with infinite context
Probably not infinite; at some point the numbers just average out too much and the information is lost.
There are some interesting things to see in the KV cache; for example, the K data seems to change little from token to token.
Since the RoPE scaling landed, maybe looking at dK across token distances could lead to something?
Surprised no one mentioned ALiBi; there are ways to train the model at size X and have it extrapolate to Y > X during inference (ALiBi being one method). Check out some blog posts from MosaicML.
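For context, ALiBi skips position embeddings entirely and instead subtracts a head-specific linear penalty, proportional to the query-key distance, from the attention scores, which is what gives it some extrapolation headroom. A rough NumPy sketch of the bias term:

```python
import numpy as np

def alibi_bias(n_heads: int, seq_len: int) -> np.ndarray:
    """Linear distance penalties added to attention scores (one slope per head).

    Returns (n_heads, seq_len, seq_len); entry [h, i, j] = -slope_h * (i - j)
    for j <= i, as in the ALiBi paper ("Train Short, Test Long").
    """
    # Geometric slopes as in the paper, assuming n_heads is a power of two.
    slopes = np.array([2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)])
    dist = np.arange(seq_len)[:, None] - np.arange(seq_len)[None, :]  # i - j
    dist = np.tril(dist)                  # causal: only past positions matter
    return -slopes[:, None, None] * dist  # added to the q @ k^T scores

print(alibi_bias(8, 4)[0])
```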
This paper was just released; it not only describes effective finetuning for long context but also publishes long-context checkpoints of LLaMA (up to 128k context length): https://huggingface.co/papers/2309.00071 https://github.com/jquesnelle/yarn
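The basic mechanism behind such finetunes is to rescale the rotary angles so that longer sequences map into the position range the model was trained on (linear "position interpolation"; YaRN refines this by treating different frequency bands differently). A toy sketch:

```python
import numpy as np

def rope_angles(positions, head_dim: int, scale: float = 1.0, base: float = 10000.0):
    """Rotary embedding angles with linear position interpolation.

    scale > 1 squeezes longer sequences into the trained position range
    (e.g. scale=4 maps 8192 positions onto 0..2048). YaRN refines this by
    scaling different frequency bands by different amounts.
    """
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
    return np.outer(np.asarray(positions) / scale, inv_freq)     # (n, head_dim/2)

# Angles for position 4096 at scale 4 equal those for position 1024 unscaled.
print(np.allclose(rope_angles([4096], 128, scale=4.0),
                  rope_angles([1024], 128)))  # True
```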
I'm wondering whether, if one wrote some smart graph code, you could do top-k attention over the sequence, then at deeper layers see if there are strong potential logits that aren't available, and go back up to shallower layers and fill in the logits, making the model a little more like when you skim a book to figure out what to read. Then, for training, you could treat the top-k tokens like a mixture of experts, sampling from them and weighting the output based on the likelihood of the chosen ones.
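The retrieval half of that idea, scoring all keys cheaply and then attending only to the top-k, might look like the sketch below in isolation; the "go back up to shallower layers" part is the genuinely novel bit and is not sketched.

```python
import numpy as np

def topk_attention(q, k, v, top_k: int = 32):
    """For one query, attend only to the top_k highest-scoring keys.

    q: (d,), k/v: (n_keys, d). Softmax is taken over just the selected keys,
    so the weighted sum costs O(top_k) instead of O(n_keys).
    """
    scores = k @ q / np.sqrt(q.shape[0])            # (n_keys,)
    idx = np.argpartition(scores, -top_k)[-top_k:]  # indices of top-k scores
    sel = scores[idx]
    weights = np.exp(sel - sel.max())
    weights /= weights.sum()
    return weights @ v[idx]

q = np.random.randn(64)
k = v = np.random.randn(4096, 64)
print(topk_attention(q, k, v).shape)  # (64,)
```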
I web-searched briefly for that and instead found the different approach at https://github.com/abertsch72/unlimiformer, which claims unlimited context length with Llama 2. Some community algorithm and optimization notes are at https://www.reddit.com/r/MachineLearning/comments/138atnt/r_unlimiformer_longrange_transformers_with/
This issue was closed because it has been inactive for 14 days since being marked as stale.
Thinking about what could be done when large language models can operate on phenomenally large context, and wondering what it might actually take to get there.
And realised this repo has a ton of really bright people in orbit who actually understand, at the brass-tacks level, what might be involved.
Assuming it's really desirable, what hacks could be done to get there?
🙏