linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

Update the casting logic of RMSNorm #201

Closed by lancerts 2 months ago

lancerts commented 2 months ago

Summary

Fix the RMSNorm logic and isolate the casting logic from the computation logic. If the casting mode is None, no casting is performed, but the computation still runs.
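To make the intent of the summary concrete, here is a minimal PyTorch sketch of an RMSNorm reference in which the casting decision lives in its own helper and the computation always runs. It is not the Triton kernel from this PR. The function names, the `casting_mode` parameter, and the mode strings "llama" and "gemma" are assumptions made for illustration; only the behavior for `None` (no cast, computation still performed) comes from the description above.

```python
import torch


def cast_for_mode(x, weight, casting_mode):
    """Casting stage only: pick the dtypes the math will run in.

    The mode names are assumptions for this sketch:
      None    -> leave both tensors in their original dtype
      "llama" -> upcast only the input to float32
      "gemma" -> upcast both the input and the weight to float32
    """
    if casting_mode is None:
        return x, weight
    if casting_mode == "llama":
        return x.to(torch.float32), weight
    if casting_mode == "gemma":
        return x.to(torch.float32), weight.to(torch.float32)
    raise ValueError(f"unknown casting_mode: {casting_mode!r}")


def rms_norm_reference(x, weight, eps=1e-6, casting_mode=None):
    """RMSNorm over the last dimension; computation runs for every mode."""
    out_dtype = x.dtype
    x_c, w_c = cast_for_mode(x, weight, casting_mode)

    # Computation stage: independent of whether any cast happened above.
    rstd = torch.rsqrt(x_c.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x_c * rstd

    if casting_mode == "llama":
        # In this sketch, "llama" casts back before scaling by the weight.
        normed = normed.to(out_dtype)

    return (normed * w_c).to(out_dtype)
```

With `casting_mode=None` the whole computation stays in the input dtype, which is the case the summary says must still produce a result.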

lancerts commented 2 months ago

Add a unit test to catch the bug?

Will do after we bump PyTorch to 2.3-2.4. The baseline torch implementation was added recently in https://github.com/pytorch/pytorch/pull/121364, so it is not available in 2.1 or 2.2.
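One way a test could cope with older PyTorch versions is to probe for the baseline at import time. The sketch below assumes the baseline referred to is `torch.nn.RMSNorm`; that name is not stated in this thread, and `HAS_TORCH_RMSNORM` and `make_baseline` are hypothetical helpers.

```python
import torch

# Assumption: the baseline from the linked PyTorch PR is torch.nn.RMSNorm.
# Probing for the attribute avoids hard-coding a version number.
HAS_TORCH_RMSNORM = hasattr(torch.nn, "RMSNorm")


def make_baseline(hidden_size, eps=1e-6):
    """Return torch's RMSNorm module if this PyTorch build has one, else None.

    A test can skip the comparison when None is returned.
    """
    if not HAS_TORCH_RMSNORM:
        return None
    return torch.nn.RMSNorm(hidden_size, eps=eps)
```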

yundai424 commented 2 months ago

How about just adding a manual baseline by removing .to(torch.float32) from https://github.com/linkedin/Liger-Kernel/blob/d338f4b9923e452baecff6d36775242a5319df4c/test/transformers/test_rms_norm.py#L25?
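A rough sketch of what such a manual baseline might look like, assuming the linked line is the float32 upcast inside a Hugging Face style reference RMSNorm module. The class name `NoCastRMSNorm` is hypothetical, and the body is a guess at the shape of that reference rather than a copy of the test file.

```python
import torch
import torch.nn as nn


class NoCastRMSNorm(nn.Module):
    """Reference RMSNorm that never upcasts (hypothetical name).

    Assumed to match an HF-style reference with the
    `.to(torch.float32)` upcast removed, so all math stays in the
    dtype of the incoming hidden states.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # No upcast here: variance and rsqrt run in the input dtype.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
```

Comparing a kernel run without casting against this module would exercise the no-cast path without depending on a newer PyTorch release.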