Closed by ggerganov.
Trying to figure this out. I found this reference helpful, as it seems to contain all the necessary code in one place: https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py
It looks like the Python Attention.forward() function has something in common with the self-attention part of the C++ build_llama().
Also, since KQ_mask already exists, it looks like not many changes are needed.
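For comparison, the prefill mask in the linked one_file_ref.py appears to be built as a lower-triangular matrix of ones, cut to a band with torch.triu, and then passed through torch.log, so allowed entries become 0 and masked ones become -inf. A minimal C++ sketch of an equivalent additive mask follows. build_swa_mask is a hypothetical helper, not llama.cpp code, and it assumes the window W counts the current token (key j is visible to query i iff 0 <= i - j < W); the Python reference may keep one more position per row.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not llama.cpp code): additive causal sliding-window
// mask for a prefill of n tokens, stored row-major as mask[i * n + j]
// (row i = query position, column j = key position).
//   0.0f      -> query i may attend to key j (log(1) in the Python reference)
//   -INFINITY -> key j is masked out         (log(0) in the Python reference)
// Assumption: the window W counts the current token, so key j is visible
// iff 0 <= i - j < W.
std::vector<float> build_swa_mask(int n, int W) {
    assert(n >= 0 && W > 0);
    std::vector<float> mask(static_cast<std::size_t>(n) * n, -INFINITY);
    for (int i = 0; i < n; ++i) {
        // visible keys for query i: j in [max(0, i - W + 1), i]
        const int j_begin = (i - W + 1 > 0) ? i - W + 1 : 0;
        for (int j = j_begin; j <= i; ++j) {
            mask[static_cast<std::size_t>(i) * n + j] = 0.0f;
        }
    }
    return mask;
}
```

For example, build_swa_mask(6, 3) leaves at most three visible keys per row (the current token and the two before it). Dropping the i - j < W condition gives back the plain causal mask, which is why the change to an existing causal mask can be small.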
Probably I can manage it in a week or two, but not sure about it.
Any news on this one? 🥹 The best low-memory models are actually based on Mistral, and being locked to a 4096-token window is very limiting for tasks like document analysis.
I think KoboldCPP introduced a similar feature with their recent rewrite of the "smartcontext" feature.
Nope, they added cache shifting. Sliding window attention is a different attention mechanism; the two do very different things.
Maybe it's not appropriate to insist on it (I apologize if this bothers you), but this feature seems to be one of the most upvoted on the dashboard. 😢
This issue was closed because it has been inactive for 14 days since being marked as stale.
It seems that since Mistral 7B last year, there hasn't been much interest in this technique. Even later Mistral models dropped it as a feature. Taking this into account, I guess we can leave this issue closed.
@ggerganov As the new Gemma 2 models use SWA (in addition to GQA, in some sort of alternating scheme?), I suggest this be revisited. As it stands, Gemma 2 pretty much falls apart after about 4k tokens of context in llama.cpp.
Just for the record: I looked at this again in January or February but didn't manage it. My understanding isn't sufficient yet, so I don't expect to finish it soon.
Still, it really seems that only a few additional lines of code are needed.
I think it may be worth re-evaluating this, since an increasing number of models use SWA (Mistral, Phi, Gemma 2 off the top of my head).
I roughly implemented sliding window attention here: https://github.com/arlo-phoenix/llama.cpp/tree/gemma2. The branch is already rebased on #8197, so it should include all the Gemma 2 bug fixes.
I have no idea whether it's correct, and the output isn't great yet, but it doesn't completely break like it does without it. For testing, I gave the 9b-it model the Bee Movie script up to the "I like jazz" part (~7000 tokens), and it managed to generate an ending where Barry leaves through the window after some conversation (previously it just repeated random stuff like "You're a bee. I'm a bee." lol).
My change does satisfy the "each token can attend to at most W tokens from the previous layer" description from the Mistral paper. The mask isn't exactly equal to the Mistral implementation, since that one applies a log at the end, but I don't think that's related to SWA. It's also missing GGUF parameters; the window is just hardcoded for now (and only enabled for Gemma 2).
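Taking that description at face value, the reach of stacked SWA layers can be bounded with a bit of arithmetic. The numbers below are Mistral 7B's (W = 4096, 32 layers), for which the Mistral paper quotes a theoretical attention span of about 131K tokens. This is a bound on how far information can propagate, not an effective context length.

```latex
% Layer k, position i attends to layer k-1 positions j with i - W <= j <= i,
% so each layer moves information back by at most W positions (reach(0) = 0).
\mathrm{reach}(k) \le \mathrm{reach}(k-1) + W \;\Rightarrow\; \mathrm{reach}(k) \le kW,
\qquad W = 4096,\; k = 32:\quad kW = 131072 \approx 131\text{K tokens}.
```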
It seems the sliding window technique in Gemma 2 is mostly there to reduce KV cache memory usage. The idea is to use a sliding window for only some of the layers (a bit like Jamba, which swaps out some layers for Mamba to save memory).
So the idea would be to start by allocating a different KV size for each layer:
// inside llama_kv_cache_init()
for (int i = 0; i < (int) n_layer; i++) {
    struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
    int kv_size_l = (i % 2 == 0)
        ? n_embd_k_gqa*kv_size  // full KV size for even layers
        : n_embd_k_gqa*4096;    // "limited" size for the rest
    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, kv_size_l);
    ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, kv_size_l);
    ...
}
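To put numbers on the memory argument, the KV cache size for the alternating layout above can be estimated directly. kv_cache_bytes is a hypothetical helper, not llama.cpp code; it assumes K and V both store n_embd_k_gqa elements per cell with the same element size, and it follows the even-full / odd-windowed split of the sketch above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not llama.cpp code): estimated KV cache size in bytes
// when even layers keep kv_size cells and odd layers keep only n_swa cells,
// following the even/odd split in the sketch above.
// Assumptions: K and V rows both have n_embd_k_gqa elements and share one
// element size (e.g. 2 bytes for F16).
std::int64_t kv_cache_bytes(int n_layer, std::int64_t n_embd_k_gqa,
                            std::int64_t kv_size, std::int64_t n_swa,
                            std::int64_t bytes_per_elem) {
    assert(n_layer >= 0 && kv_size > 0 && n_swa > 0 && bytes_per_elem > 0);
    std::int64_t total = 0;
    for (int il = 0; il < n_layer; ++il) {
        // a windowed layer never needs more cells than the full cache
        const std::int64_t cells = (il % 2 == 0) ? kv_size : std::min(n_swa, kv_size);
        total += 2 * n_embd_k_gqa * cells * bytes_per_elem; // K tensor + V tensor
    }
    return total;
}
```

For long contexts (kv_size much larger than n_swa) the total approaches roughly half of the all-global size; once kv_size <= n_swa there is no saving at all.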
Then the corresponding changes would need to be implemented in llm_build_kv_store, llm_build_kqv, etc. I haven't looked deeply into that.
I had a look at @arlo-phoenix's fork, but I'm not sure that's the right direction. KQ_mask is used for masking different sequences inside a batch, so it's probably unrelated here (see comment below).
I did it there because I thought ggerganov suggested it with the linked #3228, and it does work. I'm pretty sure KQ_mask isn't just for batch masking but also for general positional masking (the > pos check). At least, doing it with the mask made the most sense to me. But yeah, it's definitely not the most efficient way to implement it; the rest goes over my head.
@arlo-phoenix sorry, I overlooked that. KQ_mask has size [kv_size, n_batch], so clearly it also masks tokens in the KV cache, not just within the batch. If you don't mind, I can propose a cleaner PR based on yours later today. Even if it's not the most efficient way, I believe it can be a good start.
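Since KQ_mask is indexed by KV cell and batch token, the SWA condition can be expressed per cell position. The sketch below uses simplified, hypothetical kv_cell_sketch and batch_token_sketch types (the real llama.cpp cache cells and batch differ, e.g. a cell can belong to several sequences) and lays the mask out as one row of kv_size values per batch token, which is how a ggml tensor with shape [kv_size, n_batch] is stored. fill_kq_mask is hypothetical, not llama.cpp's mask-filling code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for KV cache cells and batch tokens.
struct kv_cell_sketch     { std::int32_t pos; std::int32_t seq_id; }; // pos < 0 => empty cell
struct batch_token_sketch { std::int32_t pos; std::int32_t seq_id; };

// Fill mask[t * n_kv + c] for batch token t and KV cell c.
// A cell is visible if it is non-empty, belongs to the same sequence, is not
// in the token's future, and (when n_swa > 0) lies within the last n_swa
// positions. n_swa <= 0 means plain sequence + causal masking.
std::vector<float> fill_kq_mask(const std::vector<kv_cell_sketch>& cells,
                                const std::vector<batch_token_sketch>& batch,
                                std::int32_t n_swa) {
    const std::size_t n_kv = cells.size();
    std::vector<float> mask(n_kv * batch.size(), -INFINITY);
    for (std::size_t t = 0; t < batch.size(); ++t) {
        for (std::size_t c = 0; c < n_kv; ++c) {
            const kv_cell_sketch& cell = cells[c];
            if (cell.pos < 0 || cell.seq_id != batch[t].seq_id) continue; // empty / other sequence
            if (cell.pos > batch[t].pos) continue;                         // causal: future token
            if (n_swa > 0 && batch[t].pos - cell.pos >= n_swa) continue;   // outside the window
            mask[t * n_kv + c] = 0.0f;
        }
    }
    return mask;
}
```

Passing n_swa = 0 reduces it to the usual sequence + causal masking, so a model could use one mask for global layers and a second one, built with n_swa > 0, for windowed layers.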
@ngxson Yeah, sounds good! It would be too large a change for me to do cleanly anyway, so thank you for doing it instead! I think you saw the hacky commit https://github.com/arlo-phoenix/llama.cpp/commit/265a8f2d0fe73ba928e98f8a765105c94c3165c2; that's the one that needs cleaning (for the alternating SWA/global layers, as you also commented above; it's just copy-pasted, and the actual difference is minimal). I only propose keeping the default SWA size at the Gemma 2 size, since that's what most people are interested in right now (and we already did the same for the other Gemma 2 settings), so people don't need new quants.
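The "every other layer SWA, default to the Gemma 2 window" proposal can be captured in a small hyperparameter struct plus a per-layer selector. Everything here is hypothetical: the field names are not actual llama.cpp or GGUF keys, and which parity of layers is windowed (swa_phase) is an assumption that should be read from the model's config rather than hardcoded.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical hyperparameters for per-layer SWA (illustrative names only,
// not actual llama.cpp or GGUF keys). Defaults follow the proposal above:
// a 4096-token window (Gemma 2's size) on every other layer.
struct swa_hparams_sketch {
    std::int32_t n_swa      = 4096; // window size in tokens; 0 disables SWA
    std::int32_t swa_period = 2;    // the layer pattern repeats every swa_period layers
    std::int32_t swa_phase  = 0;    // layers with il % swa_period == swa_phase are windowed (assumption)
};

// Whether layer il should use the sliding-window mask instead of the global one.
bool layer_uses_swa(const swa_hparams_sketch& hp, std::int32_t il) {
    if (hp.n_swa <= 0 || hp.swa_period <= 0) {
        return false; // SWA disabled: every layer is global
    }
    return il % hp.swa_period == hp.swa_phase;
}
```

Setting swa_period = 1 windows every layer, which is how Mistral 7B applies SWA, so a single mechanism could in principle cover both architectures.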
Even gemma.cpp, Google's reference implementation, is giving me subpar results.
The best implementation is by @foldl in his chatllm project. It gives exactly the same results as the AI Studio version of Gemma 2 27B.
Very neat code in chatllm. Really liked it!
How much more difficult would it be to add a similar change for the Mistral 7B architecture, following the changes above for Gemma 2? I'm trying to compare with what was done here:
Implementation of sliding window attention in mlc-llm
For more info, see: https://github.com/mistralai/mistral-src and the references therein.
Also: https://arxiv.org/pdf/2310.06825v1.pdf
With #3228 it should be relatively easy to support this.