keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Always run the rotary embedding layer in float32 #1508

Closed tirthasheshpatel closed 5 months ago

tirthasheshpatel commented 5 months ago

Follow-up for #1497

This PR refactors the keras_nlp.layers.modeling.rotary_embedding.RotaryEmbedding layer to always compute in float32, since other dtypes cause significant precision loss. It also updates Gemma to use this layer instead of implementing its own version of RoPE.
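A minimal NumPy sketch (not the actual KerasNLP layer) of why the compute dtype matters: building the sin/cos tables and applying the rotation in half precision drifts noticeably from the float32 result, especially at large positions. The pair-interleaved rotation convention and all names below are illustrative assumptions, not code from this PR.

```python
import numpy as np

def rope(x, max_wavelength=10000, compute_dtype=np.float32):
    # x: (seq_len, dim) with an even feature dimension.
    seq_len, dim = x.shape
    x = x.astype(compute_dtype)
    positions = np.arange(seq_len, dtype=compute_dtype)
    inv_freq = 1.0 / (
        max_wavelength ** (np.arange(0, dim, 2, dtype=compute_dtype) / dim)
    )
    # Position-dependent rotation angles, shape (seq_len, dim // 2).
    angles = np.einsum("i,j->ij", positions, inv_freq)
    cos, sin = np.cos(angles), np.sin(angles)
    # Rotate each (even, odd) feature pair by its angle.
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x2 * cos + x1 * sin
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2048, 256)).astype(np.float16)

ref = rope(x, compute_dtype=np.float32)
half = rope(x, compute_dtype=np.float16).astype(np.float32)
print("max abs error (float16 vs float32):", np.abs(ref - half).max())
```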

This PR isn't ready yet. TODO:

Colab showing the equivalence of Gemma's RoPE implementation and the RotaryEmbedding layer in KerasNLP: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing

tirthasheshpatel commented 5 months ago

> Code looks good!
>
> We should probably test this to make sure the numerics stay as close to our reference JAX implementation as they were before, and that this does not negatively impact performance.

Already done here: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing
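For reference, a hedged sketch of the kind of numerics check described above: compare the layer's output against a trusted reference implementation within a tolerance. The function names, tensor shape, and tolerances are placeholders, not code from this PR or the linked Colab.

```python
import numpy as np

def check_numerics(layer_fn, reference_fn, shape=(1, 128, 8, 64), seed=0):
    # layer_fn: the implementation under test; reference_fn: the trusted baseline.
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape).astype(np.float32)
    np.testing.assert_allclose(
        np.asarray(layer_fn(x)),
        np.asarray(reference_fn(x)),
        atol=1e-5,
        rtol=1e-5,
    )
```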