turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

RoPE Frequency Base and Frequency Scale Support #262

Open ChrisCates opened 10 months ago

ChrisCates commented 10 months ago

As of now, there is no way to modify RoPE Frequency Base and RoPE Frequency Scale.

We would need to edit rope.cu to support parameters for frequency and scale: https://github.com/turboderp/exllama/blob/21f4a12be5794692f66410ad4fb78ffaad508d00/exllama_ext/cuda_func/rope.cu#L21-L31
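
For reference, here is a minimal sketch of where those two parameters enter the math, assuming the standard RoPE formulation (`freq_base` and `freq_scale` are the proposed parameter names, not anything the kernel currently exposes):

    import torch

    def rope_tables(seq_len, head_dim, freq_base = 10000.0, freq_scale = 1.0):
        # inv_freq[i] = 1 / freq_base ** (2i / head_dim). Raising freq_base
        # stretches the rotation wavelengths (NTK-style scaling); freq_scale < 1
        # compresses positions instead (linear scaling): t' = t * freq_scale.
        inv_freq = 1.0 / (freq_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(seq_len).float() * freq_scale
        freqs = torch.outer(t, inv_freq)  # (seq_len, head_dim / 2)
        return torch.sin(freqs), torch.cos(freqs)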

We would also need to add arguments in model_init.py to support frequency and scale for RoPE: https://github.com/turboderp/exllama/blob/21f4a12be5794692f66410ad4fb78ffaad508d00/model_init.py#L29-L30

Here are proposed arguments to be added to the existing model_init.py:

    parser.add_argument("--rope-freq-base", type = float, help = "The frequency base (theta) for the RoPE kernel", default = 10000.0)
    parser.add_argument("--rope-freq-scale", type = float, help = "The frequency scale for the RoPE kernel", default = 1.0)
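
Note the types are float rather than int: the base can be non-integer (e.g. after NTK scaling) and the scale is typically fractional (e.g. 0.5 for 2x linear scaling). These flags would then have to be copied onto the model config before the sin/cos tables are built; a hypothetical sketch (the attribute names are illustrative, not necessarily exllama's actual config fields):

    # Hypothetical wiring after parser.parse_args(); the attribute names are
    # illustrative placeholders for whatever ExLlamaConfig actually uses.
    def apply_rope_args(args, config):
        config.rotary_embedding_base = args.rope_freq_base    # theta fed to rope.cu
        config.rotary_embedding_scale = args.rope_freq_scale  # position multiplier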

Note that this is important for resolving issues like #261 and #260, where inference is run at context lengths longer than the model's native length.

Ph0rk0z commented 10 months ago

It's exactly the same as alpha. BTW, the "base" for the CodeLlama base model corresponds to about alpha 100.
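
For anyone following along, the correspondence works out like this, assuming the commonly used NTK-style scaling where the effective base is the default base times alpha^(d / (d - 2)) for head dimension d:

    # NTK-style alpha scaling (the commonly used formula, assumed here)
    head_dim = 128   # Llama head dimension
    base = 10000.0   # default RoPE base
    alpha = 100.0
    effective_base = base * alpha ** (head_dim / (head_dim - 2))
    print(effective_base)  # ~1.08e6, close to CodeLlama's rope theta of 1e6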

ChrisCates commented 10 months ago

@Ph0rk0z thanks man! I was wondering why I couldn't find the relevant source, but I just found it.

https://github.com/turboderp/exllama/blob/21f4a12be5794692f66410ad4fb78ffaad508d00/model.py#L126-L127

ChrisCates commented 9 months ago

As per the discussion in issue #270, this issue is being reopened. The following is a fairly informal proposal for @turboderp to review:

First, instead of replacing the current rotary embedding calculation, we keep both as options: rope_alpha and rope_theta drive the first calculation, and rope_base and rope_frequency drive the second. We should rename the --alpha flag to --rope-alpha for extra clarity, and use something like --use-rope-alpha and --use-rope-base to flag which calculation type to use.
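
A hypothetical flag layout for this (names taken from the proposal above, not a final design):

    import argparse

    parser = argparse.ArgumentParser()
    # Existing --alpha, renamed as proposed above
    parser.add_argument("--rope-alpha", type = float, default = 1.0, help = "NTK alpha for the first calculation")
    parser.add_argument("--rope-freq-base", type = float, default = 10000.0, help = "Frequency base for the second calculation")
    parser.add_argument("--rope-freq-scale", type = float, default = 1.0, help = "Frequency scale for the second calculation")
    # Only one calculation type can be active at a time
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--use-rope-alpha", action = "store_true", help = "Derive the RoPE base from --rope-alpha")
    mode.add_argument("--use-rope-base", action = "store_true", help = "Use --rope-freq-base / --rope-freq-scale directly")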

Second, let's just port the RoPE calculations from the llama.cpp repo. This will be easier and faster, and given how rotary embeddings work, it should not be a problem.
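
As a hedged sketch, the rotation being ported boils down to the following (this assumes llama.cpp's default interleaved-pair convention; HF-format weights rotate the two halves of the head dimension instead, so a port has to match whichever layout the stored weights expect):

    import torch

    def apply_rope(x, sin, cos):
        # Rotate adjacent pairs (x[..., 2i], x[..., 2i+1]) by the angle for
        # frequency i at each position -- the interleaved convention.
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out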

Third, while not strictly necessary, an additional testing script for PPL (and maybe a review of sample outputs) would be nice, just to see what the optimal alpha, theta, or base and frequency values are. This is up for discussion and should be a separate PR.
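
A rough sketch of what such a script might look like (`eval_ppl` is a stand-in for whatever perplexity routine gets reused, e.g. from the repo's existing test scripts):

    def eval_ppl(alpha):
        # Stand-in: hook up the repo's existing perplexity evaluation here.
        raise NotImplementedError

    # Sweep candidate scaling values and report perplexity for each
    for alpha in (1.0, 1.75, 2.0, 2.5, 4.0):
        print(f"alpha = {alpha}: ppl = {eval_ppl(alpha):.3f}")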

I'd be happy to formalize this into a spec now. In terms of implementation, I will take a deep dive in a couple of weeks, assuming no one else is working on it.