awslabs / extending-the-context-length-of-open-source-llms

Apache License 2.0

MistralLite: `max_position_embeddings=32768` and `precompute_freqs_cis` with `end=128_000` #11

Status: Open. keyboardAnt opened this issue 10 months ago.

keyboardAnt commented 10 months ago

Hello!

First off, I'd like to express my appreciation for the great work shared here. I've been examining the implementation details, specifically the positional encoding: the config sets `max_position_embeddings=32768`, while `precompute_freqs_cis` is called with `end=128_000`. There's an ongoing discussion on this topic, and I thought someone here might have insights or clarifications.

Thank you in advance!

yinsong1986 commented 10 months ago

Hi @keyboardAnt

Thanks for your question!

The `max_seq_len` used for fine-tuning MistralLite was ~16K.

In terms of positional encoding, MistralLite uses `rope_theta` = 1000000, so in theory it should be able to support very long sequences (i.e., longer than 32K). However, since our training data only used a `max_seq_len` of 16K, we set the maximum positional encoding (`max_position_embeddings`) to 32768 to make sure the model still generalises well. You can change `max_position_embeddings` to something bigger than 32768, but you may see degraded model performance for input contexts longer than 32768 tokens.
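To illustrate why a larger `rope_theta` helps in theory, here is a minimal sketch (not code from this repo) that computes the wavelength, in token positions, of the slowest-rotating RoPE dimension pair. It assumes the standard RoPE frequency schedule `theta ** (-2i / dim)` and a head dimension of 128, which is an assumption on my part for a Mistral-7B-style model. With `theta=10_000` the slowest pair completes a full rotation after roughly 54K positions; with `theta=1_000_000` it takes roughly 5M positions, so positions far beyond 32K still map to distinct angles. This alone does not guarantee good quality past the trained length, which is why `max_position_embeddings` stays at 32768.

```python
import math


def longest_wavelength(theta: float, dim: int = 128) -> float:
    """Positions needed for the slowest RoPE pair to complete one full rotation.

    Assumes the standard schedule freq_i = theta ** (-2i / dim), i in [0, dim/2).
    The slowest pair is i = dim/2 - 1, i.e. freq = theta ** (-(dim - 2) / dim).
    dim=128 is an assumed head dimension, not a value taken from this thread.
    """
    slowest_freq = theta ** (-(dim - 2) / dim)
    return 2 * math.pi / slowest_freq


if __name__ == "__main__":
    for base in (10_000.0, 1_000_000.0):
        print(f"theta={base:>12,.0f}: longest wavelength ~ {longest_wavelength(base):,.0f} positions")
```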

Hope this helps you understand this topic. Thank you!

keyboardAnt commented 10 months ago

Hi @yinsong1986, thanks for replying!

I understand the theoretical part, but I'm concerned about the `precompute_freqs_cis` function being called with a hardcoded `end=128_000`. How does this affect the attention span in practice?
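For context on what `end` controls, here is a standard-library sketch modeled on the `precompute_freqs_cis` pattern from the Llama/Mistral reference code, which there uses PyTorch (`torch.outer` and `torch.polar`). This port is hypothetical and is not the repo's exact function. It shows that `end` only sets how many positions (rows) are precomputed: the rotation for any given position is identical whether the table stops at 16 or at 64, so `end=128_000` just makes the table cover positions beyond 32768 without changing the values used for shorter inputs.

```python
import cmath


def precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0) -> list[list[complex]]:
    """Hypothetical stdlib port of the RoPE table precomputation.

    Row p contains the unit complex numbers e^{i * p * freq_k} used to rotate
    dimension pair k of a query/key vector at position p.
    """
    # Per-pair rotation frequencies: theta ** (-2k / dim) for k in [0, dim/2)
    freqs = [1.0 / (theta ** (2 * k / dim)) for k in range(dim // 2)]
    # One row per position 0..end-1; end only controls how many rows exist
    return [[cmath.rect(1.0, p * f) for f in freqs] for p in range(end)]


if __name__ == "__main__":
    small = precompute_freqs_cis(dim=8, end=16)
    large = precompute_freqs_cis(dim=8, end=64)
    # The first 16 rows are identical: a larger `end` only adds rows
    print(large[:16] == small)
```

The trade-off of a large `end` is memory (`end * dim / 2` complex values), not a change in how earlier positions are encoded.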

yinsong1986 commented 10 months ago

I don't think it affects it much: at inference time we can use the Hugging Face Transformers implementation (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L97), so we don't need to use `precompute_freqs_cis` at all.
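As a rough illustration of the alternative, the sketch below computes RoPE cos/sin values on demand for only the position ids actually present, instead of reading from a fixed precomputed table. It is not the Transformers code: `rope_cos_sin` and `apply_rope` are hypothetical helpers, and the interleaved pairing `(x[2k], x[2k+1])` is an assumption (Transformers uses a "rotate half" layout, which pairs dimensions differently but has the same property). It also shows that the query/key dot product depends only on the relative offset between positions, so a table's absolute length does not change attention scores.

```python
import math


def rope_cos_sin(position_ids: list[int], dim: int, theta: float = 1_000_000.0):
    """Compute cos/sin per position on demand (hypothetical helper)."""
    inv_freq = [1.0 / (theta ** (2 * k / dim)) for k in range(dim // 2)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in position_ids]
    sin = [[math.sin(p * f) for f in inv_freq] for p in position_ids]
    return cos, sin


def apply_rope(x: list[float], cos_row: list[float], sin_row: list[float]) -> list[float]:
    """Rotate each interleaved pair (x[2k], x[2k+1]) by the angle for pair k."""
    out: list[float] = []
    for k, (c, s) in enumerate(zip(cos_row, sin_row)):
        a, b = x[2 * k], x[2 * k + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


if __name__ == "__main__":
    q = [0.3, -1.2, 0.5, 0.8, -0.4, 1.1, 0.9, -0.7]
    k = [1.0, 0.2, -0.6, 0.4, 0.7, -0.3, 0.1, 0.5]
    # Same offset (7) at small and large absolute positions
    cos, sin = rope_cos_sin([3, 10, 20_003, 20_010], dim=len(q))
    near = dot(apply_rope(q, cos[1], sin[1]), apply_rope(k, cos[0], sin[0]))
    far = dot(apply_rope(q, cos[3], sin[3]), apply_rope(k, cos[2], sin[2]))
    print(near, far)
```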