google / gemma.cpp

lightweight, standalone C++ inference engine for Google's Gemma models.
Apache License 2.0

Add Self-Extend to the gemma.cpp #60

Open namtranase opened 7 months ago

namtranase commented 7 months ago

Hi team, I checked LocalLLaMA and found that Gemma works well with the Self-Extend method. It would be awesome if this technique could be added to gemma.cpp. References:

austinvhuang commented 7 months ago

This seems interesting and quite doable. I'll need to have a closer look at the paper and revisit tomorrow.

On the tactical side, we'll want to tidy up the APIs + dispatch mechanisms for multiple alternative inference graphs. The dispatch mechanisms are OK for the limited set of 7B/2B x IT/PT, but could use a refactor before we add more combinations of inference paths.

ahxt commented 7 months ago

Glad to see that our method works well with Gemma!! Our Python implementation is here: https://github.com/datamllab/LongLM/blob/master/gemma_self_extend_patch.py and the llama.cpp implementation is here: https://github.com/ggerganov/llama.cpp/blob/cb49e0f8c906e5da49e9f6d64a57742a9a241c6a/examples/main/main.cpp#L569

We are glad to help!!!

Mooler0410 commented 7 months ago

Author here, glad to answer any questions about the details of our work.

austinvhuang commented 7 months ago

If someone wants to take a stab at this as a flag, happy to have a look at the PR / provide suggestions (add yourself as the assignee for this issue).

There's an enhancement that I think would improve the usefulness of this: %save and %load commands for KV cache state. Using the blob store headers, I don't think this would be that hard to implement. Might be a good first issue for someone who's comfortable with the codebase. I think it would enable a lot of use cases that would otherwise be impractical.
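For illustration, %save/%load might boil down to something like this (a sketch only: `data`/`num_floats` are placeholders for however KVCache exposes its buffer, and the real implementation would go through the blob store headers rather than raw stdio):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical: dump a contiguous KV-cache buffer to disk, and read it back.
bool SaveKVCache(const char* path, const float* data, size_t num_floats) {
  FILE* f = fopen(path, "wb");
  if (!f) return false;
  const size_t written = fwrite(data, sizeof(float), num_floats, f);
  fclose(f);
  return written == num_floats;
}

bool LoadKVCache(const char* path, float* data, size_t num_floats) {
  FILE* f = fopen(path, "rb");
  if (!f) return false;
  const size_t read = fread(data, sizeof(float), num_floats, f);
  fclose(f);
  return read == num_floats;
}
```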

jan-wassenberg commented 2 months ago

+1, we'd welcome a pull request for this, also happy to discuss.

jonpsy commented 1 month ago

@austinvhuang @jan-wassenberg I'd like to take a stab at this, if nobody has objections?

My background: I've been trying to break into this field, and I've had the pleasure of collaborating with the Google team in the past on the TFLite Support repository.

jan-wassenberg commented 1 month ago

Nice, sounds great, we'd be happy to collaborate with you, discuss and review :)

FYI the KVCache internals will likely change a bit to use RowVectorBatch at some point, but no big deal.

Is there anything in the current code that you think will cause difficulties?

InferenceArgs is probably a good place to add the flag.
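For example, something along these lines (a sketch assuming the visitor pattern from util/args.h; the members and flag names are hypothetical):

```cpp
// Inside InferenceArgs (sketch; members and flag names are hypothetical):
size_t self_extend_ngb_size;  // neighbor window; 0 disables Self-Extend
size_t self_extend_grp_size;  // group size for distant positions

template <class Visitor>
void ForEach(const Visitor& visitor) {
  // ... existing flags ...
  visitor(self_extend_ngb_size, "self_extend_ngb_size", size_t{0},
          "Self-Extend neighbor window size (0 = disabled).");
  visitor(self_extend_grp_size, "self_extend_grp_size", size_t{8},
          "Self-Extend group size for positions beyond the window.");
}
```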

jonpsy commented 1 month ago

Perfect, sorry for the delay. I can spin something up over the weekend. Please allow some time for me to read the codebase; I'll get back with a proposal.

jonpsy commented 1 month ago

Had a first pass through the paper. It demonstrates the method only on RoPE positional encodings, and the theory is supported only for relative positional encodings, i.e. there's no proof of it working for a model trained with sinusoidal positional encoding.

Shouldn't we have some kind of check for this?
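One option would be a compile-time guard along these lines (purely a sketch; `kEnableSelfExtend` and `kUseRope` are hypothetical config traits, nothing with these names exists today):

```cpp
// Sketch: refuse to compile Self-Extend into a model config that does not
// declare RoPE positional encoding. Both traits are hypothetical.
template <class TConfig>
constexpr bool CheckSelfExtendSupported() {
  static_assert(!TConfig::kEnableSelfExtend || TConfig::kUseRope,
                "Self-Extend has only been validated for RoPE-based models.");
  return true;
}
```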

cc: @Mooler0410 @ahxt

jan-wassenberg commented 1 month ago

Hi, https://arxiv.org/pdf/2401.01325 mentions an experiment with non-RoPE also working. We mostly do use RoPE, though. We can certainly mention in the flag description that this depends on the positional encoding.

jonpsy commented 4 weeks ago

@jan-wassenberg @austinvhuang @Mooler0410

Okay, it took me a long time to understand how transformers work, and even longer to understand how this repository implements them. So please forgive any mistakes below, and do comment if you spot one.

The paper argues that the O.O.D. issue occurs because the model has never seen positions outside its trained context window. So Self-Extend uses grouped attention, mapping a position to pos / group_size_1 so that position ids are shared within a group. This is applied only to position ids beyond the neighbor window (group_size_2 in the LongLM code).

For reference here is the code snippet from LongLM

*(screenshot of LongLM's gemma_self_extend_patch.py)*
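A rough C++ transcription of that remapping (my own helper names; a sketch of the patch's logic, not code from either repo):

```cpp
#include <cstddef>

// Self-Extend position remapping, following the LongLM patch: keys beyond
// the neighbor window share the group position pos / grp_size, and query
// positions are shifted so the mapping stays continuous at the boundary.
size_t GroupedKeyPos(size_t pos, size_t grp_size) {
  return pos / grp_size;
}

size_t GroupedQueryPos(size_t pos, size_t grp_size, size_t ngb_size) {
  return pos / grp_size + ngb_size - ngb_size / grp_size;
}
```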

We already have RoPE implemented as PostQK.

In Gemma, we have:

```cpp
// gemma/activations.h
class Activations {
  // Same as before; additionally need space for grouped-position states.
  // Size: (batch_size, kHeads * kQKVDim) -- only the Q head is flattened.
  RowVectorBatch<float> grp_q;
};
```

We do the forward pass twice: once for Prefill and once for GenerateT.

I think the q matrix has the following shape:

```
# rows = token batch size
row 1 [ --- head 1 (QKV flattened) ---, --- head 2 (QKV flattened) ---, --- head 3 (QKV flattened) --- ]
row 2 [ --- head 1 (QKV flattened) ---, --- head 2 --- ... (same as row 1) ]
```
1. We'd need a flag inside PostQK so that when applying positional encodings for K, we also apply the grouped positional encoding (grp_k).

    
```cpp
    // Apply positional encodings for K (and copy KV to cache if MHA).
    pool.Run(
        ...
        if constexpr (kIsMHA) {
          // For MHA, copy KV into the KV cache from scratch space (see above).
          const float* HWY_RESTRICT q =
              activations.q.Batch(interleaved_idx) + head * kQStride;

          // Another float pointer to store the grouped key cache?
          float* HWY_RESTRICT grp_k = kv_caches[query_idx].g_k_cache();

          // Self-Extend parameters.
          const size_t ngb_size = TConfig::self_extend_ngb_size;
          const size_t grp_size = TConfig::self_extend_grp_size;

          // First, group the key position.
          const size_t grp_k_pos = pos / grp_size;

          // Now apply RoPE based on the grouped position; comes in handy later.
          Rope(grp_k, kQKVDim, grp_k_pos);
        }
```

2. We also need to compute `grp_q` (as stored in Activations) and use it to calculate the score:
```cpp
// This should be done during our fwd pass.
  // For each head (token, query), compute Q.K, softmax, and weighted V.
  pool.Run(
      0, kHeads * num_interleaved,
      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        ...
        KVCache& kv_cache = kv_caches[query_idx];
        float* HWY_RESTRICT q =
            activations.q.Batch(interleaved_idx) + head * kQStride;

        // Apply rope and scaling to Q.
        const size_t pos = batch_start + batch_idx;
        PostQK<TConfig>(q, pos, layer);
        MulByConst(kQueryScale, q, kQKVDim);

        // Self-Extend: for positions beyond the neighbor window, score with
        // the grouped Q and K instead of q and k.
        if (TConfig::kEnableSelfExtend && pos > ngb_size) {
          // Same data as activations.q (from MatMul), stored in a separate block.
          float* HWY_RESTRICT grp_q =
              activations.grp_q.Batch(interleaved_idx) + head * kQStride;

          // Shifted group position (per the LongLM patch).
          const size_t s_g_pos = pos / grp_size + ngb_size - ngb_size / grp_size;
          PostQK<TConfig>(grp_q, s_g_pos, layer);
          MulByConst(kQueryScale, grp_q, kQKVDim);

          const float score = Dot(grp_q, grp_k, kQKVDim);
        }
```
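As a sanity check on that shifted formula (it's the one from the LongLM patch): with `ngb_size = 512` and `grp_size = 8`, a query at `pos = 512` maps to `512/8 + 512 - 512/8 = 512`, so grouped and normal positions join up exactly at the window boundary, while `pos = 1024` maps to `128 + 512 - 64 = 576`.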
jonpsy commented 4 weeks ago

If you're interested in mentoring me to help merge this, I can convert the above into a Tech doc in Google Docs and you can point out things there.

Awaiting your positive response

jan-wassenberg commented 3 weeks ago

Hi, great that you've mapped this to our code :) This sounds correct to me. Note that the activations.q shape depends on kQStride: for MHA it is indeed Q,K,V; otherwise only Q.

Would be happy to advise, but I will be without internet access from Sep11-17. A doc sounds fine, or we can also proceed directly to a pull request if you prefer?

jonpsy commented 3 weeks ago

I'd love that. I'm aware of values being stored as [Q, Q, ..., K, K, ..., V, V, ...] in non-MHA cases. I was under the assumption that we'd be restricting this to the purely MHA case; guess that's wrong.

I was thinking of a doc so that we could have a unified tracker with objectives and next steps, where you could easily comment/edit. For reference, I did something similar with the TensorFlow team when I collaborated with them.

Great! I can definitely fast-track a pull request for this. I understand you won't be available from Sep 11-17; since I'm also handling a full-time job on the side, I can try to spin up an MVP while you're unavailable.

Let me know if this arrangement works for you.

jonpsy commented 2 weeks ago

I have holidays from the 23rd to the 27th, where I can dedicate good time to this and get a bulk of the work done. I'd love to go back and forth with feedback and suggestions so we can quickly integrate this amazing feature!

Just FYI, we'd also need to plan how we'd test whether it's working, and benchmark to confirm its practicality. I can make an .md file for the report.

jan-wassenberg commented 2 weeks ago

:) Yes, the various Gemma models can be any of MHA, GQA or MQA, so we have to handle both the MHA and non-MHA case.

Sure, a doc sounds useful for commenting. Thanks for suggesting it :)

Yes indeed, I'm back and catching up. Working on this together next week sounds great!

> plan how we'd test whether it's working and benchmark to confirm its practicality

I am not very familiar with evals in this space, but understand that summarization and needle-in-haystack tests might work. Do you or others have any particular suggestions?

jonpsy commented 1 week ago

@jan-wassenberg yes! The original code proposed the needle-in-a-haystack problem, and they also have a dataset for this. See for example

They've used Gemma-7B-it for one of their benchmarks, as is evident here, and we do have that model available.
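To make that concrete, the harness could be as simple as this (illustrative only; the filler sentence, pass-key format, and depth are my own choices, not the LongLM dataset):

```cpp
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

// Build a haystack of filler sentences with a "needle" inserted at a given
// depth; the model passes if its completion contains the pass key.
std::string BuildHaystack(size_t num_sentences, size_t needle_at,
                          const std::string& pass_key) {
  std::ostringstream prompt;
  for (size_t i = 0; i < num_sentences; ++i) {
    if (i == needle_at) {
      prompt << "The pass key is " << pass_key << ". Remember it. ";
    }
    prompt << "The grass is green and the sky is blue. ";
  }
  prompt << "\nWhat is the pass key?";
  return prompt.str();
}

int main() {
  const std::string prompt = BuildHaystack(2000, 700, "42817");
  std::cout << prompt << "\n";
  // Feed `prompt` to gemma.cpp with and without Self-Extend and check
  // whether the completion contains "42817".
  return 0;
}
```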

jan-wassenberg commented 1 week ago

Got it, needle in a haystack sounds like a good eval then :)