google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

Is max_position_embeddings=8096 necessary in the 2b model? #41

Open agiwave opened 4 months ago

agiwave commented 4 months ago

I just tried some small changes on the '2b' model:

1. Limit max_position_embeddings from 8096 to 256. :)
2. Trim the kv-cache in GemmaAttention to max_position_embeddings (256) (see the sketch after this list).
3. Remove the limit on the output length of model.generate.

Generation still works fine and can produce about 400 tokens for the prompt "The life meaning is".
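For reference, a minimal sketch of what step 2 could look like: keep only the most recent `window` key/value entries before computing attention scores. The tensor layout and the way this hooks into GemmaAttention are assumptions for illustration, not the repo's actual code:

```python
import torch

def trim_kv_cache(k_cache: torch.Tensor, v_cache: torch.Tensor, window: int = 256):
    """Keep only the most recent `window` positions of the KV cache.

    Assumes cache tensors shaped [batch, num_heads, seq_len, head_dim];
    the real GemmaAttention cache layout may differ.
    """
    if k_cache.size(2) > window:
        k_cache = k_cache[:, :, -window:, :]
        v_cache = v_cache[:, :, -window:, :]
    return k_cache, v_cache

# Inside the attention forward pass, before computing scores (sketch only):
#   k_cache, v_cache = trim_kv_cache(k_cache, v_cache, window=256)
#   scores = torch.matmul(q, k_cache.transpose(-2, -1)) / math.sqrt(head_dim)
```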

Does that mean:

1. The older kv-cache entries are not necessary, and the model can store and compress long-context information into a 256-entry kv-cache (per each of the 18 layers)?
2. Could we try training the model this way (with at most a 256-entry kv-cache)?
3. If the above is true, can we decrease the training and generation complexity tremendously, from O(L·L·D) to O(256·L·D) = O(L·D)? (A rough back-of-the-envelope check is sketched below.)
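As a purely illustrative check of the complexity claim in point 3 (the numbers below are assumptions, not measurements of the 2b model):

```python
# Rough attention-cost comparison per layer, ignoring constant factors.
L, W, D = 8192, 256, 2048        # sequence length, window size, hidden size (illustrative)
full_cost = L * L * D            # full causal attention: O(L^2 * D)
window_cost = W * L * D          # fixed 256-token window: O(W * L * D) = O(L * D)
print(full_cost / window_cost)   # -> 32.0, i.e. L / W
```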

agiwave commented 4 months ago

Maybe we can extend Gemma's context length to an unlimited size (depending on the compression rate) in this way, with a limited kv-cache length (256, or a little more?), at linear complexity.

pengchongjin commented 4 months ago

2. Trim the kv-cache in GemmaAttention to max_position_embeddings (256).

Do you mean using a sliding window of size 256 as you generate the output tokens?

I think this is an interesting observation. I believe there is some related work in the literature that tries to use a sliding window to extrapolate the context length. It sounds like you are doing something similar.

agiwave commented 4 months ago

Yeah. I tried limiting max_position_embeddings to 256 and generated an answer of more than 400 tokens, and it looked like it worked well. I was hoping the model could compress context information from far beyond the last 256 tokens, so I tested it. I told the model my name first, followed by about 300 tokens of other information, and at the end asked Gemma, "Do you know what my name is?" Gemma couldn't give me the right answer. So Gemma has no sliding-window memory; this only works within the 256-token attention scope. Hmm, a bit of a letdown :).
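For anyone who wants to reproduce the probe, here is a rough sketch of the test. The filler text and the generate() call are assumptions for illustration; the exact API depends on this repo:

```python
# Hypothetical memory probe: does a fact ~300 tokens back survive a 256-token cache?
name_fact = "My name is Alice. Please remember it."
filler = " ".join(["The weather today is mild and the streets are quiet."] * 40)  # a few hundred tokens
question = "Do you know what my name is?"

prompt = f"{name_fact}\n{filler}\n{question}"

# With the trimmed 256-token cache, the name falls outside the attention window,
# so the model has no way to recover it. The exact call depends on this repo's API,
# e.g. something along the lines of:
#   answer = model.generate(prompt, output_len=32)
#   print(answer)
```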