Closed: tirthasheshpatel closed this pull request 3 months ago
PyTorch's SiLU always runs in float32. Running in half-precision causes catastrophic cancellation and leads to huge errors. This PR fixes this issue for both LLaMA and Mistral.
Here are the PyTorch implementations:
CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235
CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu
Colab verifying this behavior: https://colab.research.google.com/drive/1v5CNVkWJtyIcQVbh-f51GKbqvrvfDyVd?usp=sharing
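For readers who don't want to open the Colab, a minimal sketch like the one below (my own illustration, not code from this PR) shows the kind of error involved: it evaluates the naive SiLU expression `x * sigmoid(x)` entirely in float16 and compares it against the same expression computed in float32 and cast back, using PyTorch's float32 `silu` as the reference.

```python
# Sketch only: compare half-precision SiLU math against float32-upcast SiLU.
import torch

x32 = torch.linspace(-20.0, 20.0, steps=4096, dtype=torch.float32)
x16 = x32.to(torch.float16)

ref = torch.nn.functional.silu(x32)                          # float32 reference
naive_fp16 = x16 * torch.sigmoid(x16)                        # all math in float16
upcast = (x16.float() * torch.sigmoid(x16.float())).half()   # compute in float32, cast back

print("max abs error, float16 math:  ", (naive_fp16.float() - ref).abs().max().item())
print("max abs error, float32 upcast:", (upcast.float() - ref).abs().max().item())
```

Computing in float32 and casting the result back matches what the linked PyTorch kernels do internally for reduced-precision inputs, which is the behavior this PR adopts for the LLaMA and Mistral backbones.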