## Summary
Add some easy checks for `weight.requires_grad` to skip allocating and calculating weight gradients when they're not needed. The weight gradient matrix can be pretty large, so this can also be a significant memory savings.
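A minimal sketch of the pattern, assuming a chunked backward pass that accumulates into preallocated buffers (the helper names here are hypothetical, not Liger-Kernel's actual internals):

```python
import torch

def allocate_grad_buffers(_input: torch.Tensor, weight: torch.Tensor):
    grad_input = torch.zeros_like(_input)
    # The weight-gradient buffer is vocab_size x hidden_size -- often the
    # largest allocation in the backward pass -- so only create it when
    # autograd will actually consume it.
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    return grad_input, grad_weight

def accumulate_weight_grad(grad_weight, grad_logits_chunk, input_chunk):
    # When grad_weight is None we also skip the (V, BT) @ (BT, H) matmul,
    # saving compute as well as memory.
    if grad_weight is not None:
        grad_weight.addmm_(grad_logits_chunk.t(), input_chunk)
```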
Also, a small micro-optimization: skip the `.item()` call on `total_n_non_ignore` (the subsequent calculations work fine with the tensor form) to defer CUDA synchronization. Otherwise it would wait for all the `torch.zeros` initializations on the preceding lines to complete, which can take a non-trivial amount of time.
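A sketch of the idea; `loss_per_token` is a stand-in name for whatever per-token loss buffer the kernel has already filled:

```python
import torch

def mean_loss(loss_per_token: torch.Tensor, target: torch.Tensor,
              ignore_index: int = -100) -> torch.Tensor:
    total_n_non_ignore = (target != ignore_index).sum()
    # Dividing by the 0-dim tensor behaves the same as dividing by a Python
    # int, but calling total_n_non_ignore.item() would block the host until
    # every previously queued kernel (including the torch.zeros buffer
    # initializations) had finished executing.
    return loss_per_token.sum() / total_n_non_ignore
```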
## Testing Done
The existing unit test already covers the case where the weight does not have gradients enabled, and the forward/backward passes still succeed: https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165
And the preceding test verifies the 'normal' case where the weight gradients are needed.
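For illustration, a plain-PyTorch analogue of that frozen-weight case (not the fused kernel itself): the backward pass must run cleanly while producing no weight gradient.

```python
import torch
import torch.nn.functional as F

BT, H, V = 32, 64, 128
weight = torch.randn(V, H, requires_grad=False)  # frozen: no grad expected
_input = torch.randn(BT, H, requires_grad=True)
target = torch.randint(0, V, (BT,))

loss = F.cross_entropy(_input @ weight.t(), target)
loss.backward()

assert _input.grad is not None  # input still receives a gradient
assert weight.grad is None      # and no weight-gradient buffer exists
```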
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence