linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
BSD 2-Clause "Simplified" License

Optimize fused_linear_cross_entropy when weight does not require grads #237

Closed hansonw closed 1 week ago

hansonw commented 1 week ago

Summary

Add some simple checks on weight.requires_grad to skip allocating and calculating weight gradients when they're not needed. The weight gradient matrix can be quite large, so this can also yield significant memory savings.
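
A minimal sketch of the pattern (not the actual Liger-Kernel kernel code, and the function names here are illustrative): the weight-gradient buffer is only allocated, and its accumulation only performed, when the weight actually requires gradients.

```python
import torch

def backward_buffers_sketch(_input, weight):
    grad_input = torch.zeros_like(_input)
    # Skip the large (vocab_size, hidden_size) buffer entirely when the
    # weight is frozen; downstream code treats None as "no weight grad".
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    return grad_input, grad_weight

def accumulate_chunk_sketch(grad_weight, logits_chunk, input_chunk):
    # Guard the expensive (vocab_size, hidden_size) accumulation the same way.
    if grad_weight is not None:
        grad_weight.add_(logits_chunk.t() @ input_chunk)
    return grad_weight

if __name__ == "__main__":
    weight = torch.randn(8, 4, requires_grad=False)  # e.g. a frozen lm_head
    x = torch.randn(2, 4)
    grad_input, grad_weight = backward_buffers_sketch(x, weight)
    assert grad_weight is None  # no large buffer allocated for the frozen weight
```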

Also, a small micro-optimization: skip the .item() call on total_n_non_ignore (the subsequent calculations work fine with the tensor form) to defer CUDA synchronization; otherwise that call has to wait for all the torch.zeros initializations on the preceding lines to complete, which can take a non-trivial amount of time.
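
A hedged illustration of that micro-optimization (standalone example, not the PR's code): calling .item() forces a device-to-host copy and therefore synchronizes with every kernel already queued on the stream, while keeping the count as a tensor lets the later arithmetic proceed without a sync.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
target = torch.randint(0, 100, (4096,), device=device)
ignore_index = -100

# Before: .item() copies to the host, which blocks until all previously
# queued kernels (e.g. the torch.zeros initializations) have finished.
# total_n_non_ignore = (target != ignore_index).sum().item()

# After: keep it on-device; the later division works the same with a
# 0-dim tensor, and no synchronization is forced here.
total_n_non_ignore = (target != ignore_index).sum()
loss_sum = torch.zeros((), device=device)
mean_loss = loss_sum / total_n_non_ignore
```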

Testing Done

The existing unit test already covers the case where the weight does not have gradients enabled, and both the forward and backward passes still succeed: https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165

And the preceding test verifies the 'normal' case where the weight gradients are needed.