Closed lancerts closed 2 months ago
Add a unit test to catch the bug?
Will do after we bump PyTorch to 2.3/2.4. The baseline torch implementation was only added recently (https://github.com/pytorch/pytorch/pull/121364), so on 2.1 and 2.2 the test cannot find it.
How about just adding a manual baseline by removing the .to(torch.float32) cast from https://github.com/linkedin/Liger-Kernel/blob/d338f4b9923e452baecff6d36775242a5319df4c/test/transformers/test_rms_norm.py#L25?
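A manual baseline along those lines might look like the sketch below. This is an assumption about what the Llama-style reference in `test_rms_norm.py` looks like (the class name `BaselineRMSNorm` is hypothetical); the point is that the variance is computed in the input dtype, with no `.to(torch.float32)` upcast:

```python
import torch
import torch.nn as nn


class BaselineRMSNorm(nn.Module):
    # Hypothetical manual baseline: a Llama-style RMSNorm reference
    # with the .to(torch.float32) upcast removed, so all math runs
    # in the input dtype.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Variance over the last dimension, in the original dtype
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
```

Such a baseline works on any PyTorch version, so the unit test would not need to wait for the 2.3/2.4 bump.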
Summary
Fix the logic of RMSNorm and isolate the casting logic from the computation logic. If casting_mode=None, no casting is performed but the computation still runs.
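As a rough illustration of the separation described above, here is a minimal functional sketch (not the kernel itself, and the `casting_mode` values shown are assumptions): the dtype handling is pulled out of the core normalization, so passing `casting_mode=None` skips the upcast/downcast entirely while the computation itself is unchanged.

```python
import torch


def rms_norm(x, weight, eps=1e-6, casting_mode=None):
    # Sketch only: casting policy is isolated from the computation.
    orig_dtype = x.dtype

    # Casting step (policy): upcast activations when requested.
    # "llama" here is an assumed mode name for float32 upcasting.
    if casting_mode == "llama":
        x = x.to(torch.float32)

    # Computation step: identical regardless of casting_mode.
    variance = x.pow(2).mean(-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps)

    # Casting step (policy): restore the original dtype if we upcast.
    if casting_mode == "llama":
        out = out.to(orig_dtype)

    # casting_mode=None performed no casts, but the normalization
    # above still ran in the input dtype.
    return weight * out
```

For float32 inputs the two paths coincide, since the upcast is a no-op; they only diverge for half-precision inputs, where the cast changes the precision of the intermediate math.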