keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Fix the rotary embedding computation in LLaMA #1544

Closed · tirthasheshpatel closed this 3 months ago

tirthasheshpatel commented 3 months ago

The LLaMA backbone ignored the `start_index` parameter when computing rotary embeddings, which led to numerical issues during generation. This PR fixes that, and it also updates the reverse embedding layer in both Mistral and LLaMA to run in `compute_dtype` instead of full precision. This matches how Hugging Face (HF) computes it, so it helps bring the numerics closer.
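
For context, here is a minimal sketch (not the PR's diff) of the first bug, using `keras_nlp.layers.RotaryEmbedding` directly. During cached generation each new token arrives alone and must be rotated at its true position via `start_index`; if `start_index` is ignored, the token is rotated as if it were position 0:

```python
import numpy as np
from keras import ops
from keras_nlp.layers import RotaryEmbedding

rope = RotaryEmbedding()

# Rotate a full 8-token sequence in one shot: shape (batch, seq, feature).
full = np.random.uniform(size=(1, 8, 16)).astype("float32")
out_full = ops.convert_to_numpy(rope(full))

# During cached generation, the token at position 7 arrives alone. Without
# start_index it gets rotated as position 0, which disagrees with out_full.
step = full[:, 7:8, :]
wrong = ops.convert_to_numpy(rope(step))                 # rotated as position 0
right = ops.convert_to_numpy(rope(step, start_index=7))  # rotated as position 7

assert not np.allclose(wrong, out_full[:, 7:8, :], atol=1e-5)
np.testing.assert_allclose(right, out_full[:, 7:8, :], atol=1e-5)
```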
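
And a hedged sketch of the second change, the reverse embedding dtype. The `reverse_embed` helper below is illustrative, not the actual layer code; it shows the logits projection staying in `compute_dtype` rather than upcasting to full precision:

```python
import numpy as np
from keras import ops

def reverse_embed(hidden_states, embedding_matrix, compute_dtype="float16"):
    # Project hidden states back onto the vocabulary with the (tied)
    # embedding matrix. Previously this ran in full precision (float32);
    # the fix keeps it in compute_dtype, matching the HF implementation.
    kernel = ops.cast(ops.transpose(embedding_matrix), compute_dtype)
    hidden = ops.cast(hidden_states, compute_dtype)
    return ops.matmul(hidden, kernel)  # (batch, seq, vocab_size) logits

# Example: 10-token sequence, hidden size 32, vocabulary of 100.
logits = reverse_embed(
    np.random.uniform(size=(1, 10, 32)).astype("float32"),
    np.random.uniform(size=(100, 32)).astype("float32"),
)
```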