keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Always run SiLU activation in float32 for LLaMA and Mistral #1540

Closed by tirthasheshpatel 3 months ago

tirthasheshpatel commented 3 months ago

PyTorch's SiLU activation always runs in float32, even when the inputs are half precision. Computing SiLU directly in half precision causes catastrophic cancellation and leads to large numerical errors. This PR fixes the issue for both LLaMA and Mistral by always running the activation in float32.
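
In Keras 3 terms, the change amounts to something like the sketch below (illustrative only, not the exact diff; the helper name and the `output_dtype` argument are assumptions):

```python
from keras import ops


def silu_float32(x, output_dtype):
    """Apply SiLU in float32 and cast the result back to `output_dtype`."""
    # Up-cast so the sigmoid-and-multiply inside SiLU is computed in
    # float32 rather than float16/bfloat16, then restore the dtype the
    # surrounding layer expects (e.g. `self.compute_dtype` in a Layer).
    x = ops.cast(x, "float32")
    return ops.cast(ops.silu(x), output_dtype)
```

Inside the models' gated feedforward blocks, the call would look roughly like `down_proj(silu_float32(gate_proj(x), self.compute_dtype) * up_proj(x))`; the projection names here are placeholders, not the exact KerasNLP attribute names.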

Here are the PyTorch implementations:

CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235

CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu

Colab verifying this behavior: https://colab.research.google.com/drive/1v5CNVkWJtyIcQVbh-f51GKbqvrvfDyVd?usp=sharing
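
For reference, a small standalone check in the spirit of the Colab (a sketch, not the notebook itself; the seed and input scale are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1 << 16) * 8.0  # float32 inputs with a wide range

# Reference: SiLU computed entirely in float32.
ref = F.silu(x)

x_half = x.to(torch.float16)

# Naive half-precision SiLU: sigmoid and multiply both run in float16.
naive = x_half * torch.sigmoid(x_half)

# PyTorch's silu kernel (the CPU/CUDA links above) up-casts half inputs
# to float32 internally, so only the final cast back loses precision.
fused = F.silu(x_half)

print("naive float16 max abs error:", (naive.float() - ref).abs().max().item())
print("fused float16 max abs error:", (fused.float() - ref).abs().max().item())
```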