danielhanchen closed this 3 weeks ago.
Fixes RMS LayerNorm downcasting prematurely: the normalized output was cast back to the input dtype before the weight was applied. We move the downcast to the very end.
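A minimal PyTorch sketch of the reordering, assuming a Gemma-style RMSNorm that scales by `(1 + weight)` with a zero-initialized weight. The class name `RMSNorm` and the helper `premature_downcast` are illustrative and are not claimed to match the repository code line for line. Normalization and the weight multiply both run in float32, and the result is rounded to the input dtype once, at the end.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Illustrative Gemma-style RMSNorm that only downcasts at the very end."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Assumption: the scale is stored as an offset from 1, so zeros mean identity.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root-mean-square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and apply (1 + weight) entirely in float32 ...
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())
        # ... then cast back to the input dtype exactly once.
        return output.type_as(x)


def premature_downcast(norm: RMSNorm, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical buggy ordering, for contrast: the cast back to x.dtype happens
    # before the weight multiply, so the result is rounded to low precision twice.
    output = norm._norm(x.float()).type_as(x)
    return output * (1.0 + norm.weight.type_as(x))
```

With bfloat16 inputs, `RMSNorm.forward` matches a float32 reference rounded once to bfloat16, whereas `premature_downcast` adds a second rounding step.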
Fixes the embedding scaling normalizer (sqrt(hidden_size)) being upcast to float32. The normalizer must instead be in float16 or bfloat16, matching the activations.
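A minimal sketch of the dtype-matched normalizer, assuming embeddings are scaled by `sqrt(hidden_size)` as in Gemma. The function name `scale_embeddings` is hypothetical. Building the normalizer in the activation dtype rounds it to that precision: for Gemma 7B (`hidden_size = 3072`), sqrt(3072) ≈ 55.4256 becomes 55.5 in bfloat16 and 55.4375 in float16, instead of staying at float32 precision.

```python
import torch


def scale_embeddings(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Hypothetical helper: scale embeddings by sqrt(hidden_size) in the activation dtype."""
    # Create the normalizer directly in hidden_states.dtype (float16 / bfloat16)
    # rather than as a default float32 tensor, so its value is rounded to that precision.
    normalizer = torch.tensor(hidden_size**0.5, dtype=hidden_states.dtype)
    return hidden_states * normalizer


x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.bfloat16)
print(scale_embeddings(x, 3072))  # normalizer is 55.5 in bfloat16
```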
Thanks, @danielhanchen!
![unnamed](https://github.com/google/gemma_pytorch/assets/23090290/8dbfc5e4-9a17-4fce-ac82-97835cc61cec)
![unnamed-1](https://github.com/google/gemma_pytorch/assets/23090290/428f4d34-90a8-4556-b8d6-1670e0e41e38)