keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Run the LLaMA and Mistral RMS Layer Norm in float32 #1532

Closed tirthasheshpatel closed 3 months ago

tirthasheshpatel commented 3 months ago

The LLaMA and Mistral RMS layer norm should always run in float32, regardless of the compute dtype. This PR fixes that bug in our implementation.
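For context, the idea behind the fix can be sketched as follows. This is not the PR's actual diff, just a minimal NumPy illustration of the technique: upcast the activations to float32 before computing the root-mean-square statistic (which is unstable in float16/bfloat16), then cast the result back to the input dtype. The function name `rms_layer_norm` and its signature are hypothetical.

```python
import numpy as np

def rms_layer_norm(x, scale, epsilon=1e-6):
    # Hypothetical sketch of the fix: compute the RMS statistic
    # in float32 even when inputs arrive in a lower precision.
    input_dtype = x.dtype
    x32 = x.astype(np.float32)
    # Mean of squares along the feature axis, accumulated in float32.
    variance = np.mean(np.square(x32), axis=-1, keepdims=True)
    normed = x32 * (1.0 / np.sqrt(variance + epsilon))
    # Apply the learned scale, then cast back to the caller's dtype.
    return (normed * scale.astype(np.float32)).astype(input_dtype)

x = np.random.randn(2, 4).astype(np.float16)
out = rms_layer_norm(x, scale=np.ones(4, dtype=np.float32))
print(out.dtype)  # the output dtype still matches the input (float16)
```

The key point is that only the internal reduction runs in float32; the layer's input and output dtypes are unchanged, so the rest of a mixed-precision model is unaffected.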

tirthasheshpatel commented 3 months ago

@mattdangerw Addressed the review comments. Let me know if the diff looks good to you now!

mattdangerw commented 3 months ago

Looks good besides that one potential name change. Thanks!