google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0
5.19k stars, 492 forks

Fix downcasting and upcasting #67

Closed. danielhanchen closed this pull request 3 weeks ago.

danielhanchen commented 3 weeks ago
  1. Fixes RMS LayerNorm downcasting prematurely: the cast back to the input dtype is moved to the very end.

  2. Fixes the embedding scaling: the normalizer (the square root of the hidden size) that scales the embeddings was being upcast to float32. The normalizer must instead be float16 or bfloat16.
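The ordering in item 1 can be illustrated with a plain-Python sketch (not the repository's code). It emulates bfloat16 with a hypothetical `to_bf16` helper that rounds a value to float32 and then to the nearest bfloat16 (round-half-to-even), and it assumes a Gemma-style `(1 + weight)` scale; `rms_norm` and its `downcast_early` flag are illustrative names. With `downcast_early=True` the normalized value is rounded to bfloat16 before the weight multiply and rounded again afterwards; otherwise the math stays in full precision and is rounded once at the very end.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to the nearest bfloat16 value, returned as a float.

    x is first rounded to float32, then the low 16 bits of the float32 bit pattern
    are rounded away with round-half-to-even. NaN/inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # break exact ties toward an even bfloat16 mantissa
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rms_norm(xs, weights, eps=1e-6, downcast_early=False):
    """Illustrative RMSNorm over one vector with an assumed (1 + weight) scale.

    Inputs are assumed to already be bfloat16 values; outputs are rounded to bfloat16.
    """
    # Normalize in full precision (float32 in PyTorch; Python floats here).
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    out = []
    for x, w in zip(xs, weights):
        normed = x * inv_rms
        if downcast_early:
            # Premature downcast: round, scale in bfloat16, then round the product again.
            out.append(to_bf16(to_bf16(normed) * to_bf16(1.0 + w)))
        else:
            # Downcast only at the very end: scale in full precision, round once.
            out.append(to_bf16(normed * (1.0 + w)))
    return out
```

With inputs `[3.0, 4.0]` and weights `[0.5, 0.5]`, the second output is 1.703125 when downcasting early and 1.6953125 when downcasting at the end. The full-precision value is about 1.69706, so the single rounding lands closer; rounding twice compounds the error.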
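Item 2 can be sketched the same way, again in plain Python with a hypothetical `to_bf16` emulation of bfloat16 rounding. `scale_embedding` is an illustrative helper that multiplies one bfloat16 embedding value by sqrt(hidden_size) and rounds the product once to bfloat16; its `normalizer_in_bf16` flag chooses whether the normalizer itself is first rounded to bfloat16, as item 2 asks, or kept in full precision. The hidden size of 3072 used below is only an example.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to the nearest bfloat16 value, returned as a float.

    x is first rounded to float32, then the low 16 bits of the float32 bit pattern
    are rounded away with round-half-to-even. NaN/inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # break exact ties toward an even bfloat16 mantissa
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scale_embedding(e: float, hidden_size: int, normalizer_in_bf16: bool) -> float:
    """Illustrative: scale one bfloat16 embedding value by sqrt(hidden_size)."""
    # Full-precision normalizer (float32 in PyTorch; a Python float here).
    normalizer = math.sqrt(hidden_size)
    if normalizer_in_bf16:
        # Round the normalizer itself to bfloat16 first, as item 2 asks.
        normalizer = to_bf16(normalizer)
    # The product is formed in full precision and rounded to bfloat16 once.
    return to_bf16(e * normalizer)
```

For a hidden size of 3072, sqrt(3072) ≈ 55.4256 rounds to 55.5 in bfloat16, so an embedding value of 1.125 scales to 62.25 with the full-precision normalizer but to 62.5 with the bfloat16 one.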

pengchongjin commented 3 weeks ago

Thanks, @danielhanchen!