linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

Support reduction = None and "sum" in Cross Entropy #83

Closed ByronHsu closed 3 weeks ago

ByronHsu commented 2 months ago

🚀 The feature, motivation and pitch

Currently we hardcode the reduction as "mean", so we don't have feature parity with https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html, which also supports reduction="sum" and reduction="none".
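For reference, the three torch-style reduction modes behave as sketched below. This is a plain-Python illustration of the semantics, not Liger's Triton kernel:

```python
import math

def cross_entropy(logits, target, reduction="mean"):
    """Cross entropy with torch-style reduction modes.

    `logits` is a list of rows of raw scores, `target` a list of class
    indices. A reference sketch only; real kernels operate on tensors.
    """
    losses = []
    for row, t in zip(logits, target):
        # log-softmax via the log-sum-exp trick for numerical stability
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(lse - row[t])  # -log p(target)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses  # unreduced per-token losses
    raise ValueError(f"unknown reduction: {reduction}")
```

With reduction="none" the caller gets the per-token loss vector, which is what downstream code needs for things like per-sample weighting.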

Note: label smoothing is also missing; it is tracked separately in https://github.com/linkedin/Liger-Kernel/issues/81

Alternatives

No response

Additional context

No response

skyshine102 commented 1 month ago

It would be great to have z-loss-related functionality (i.e. lse_square_scale, return_z_loss) in cross entropy, as in https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py
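For context, z-loss adds a penalty on the squared log-partition function z = logsumexp(logits), which discourages logit drift. A plain-Python sketch of the idea (parameter names mirror flash-attention's API; this is not the fused kernel):

```python
import math

def cross_entropy_with_z_loss(logits, target, lse_square_scale=0.0,
                              return_z_loss=False):
    """Mean cross entropy plus the z-loss term lse_square_scale * z**2,
    where z = logsumexp(row). Illustrative sketch only."""
    ce_total, z_total = 0.0, 0.0
    for row, t in zip(logits, target):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        ce_total += lse - row[t]                 # standard NLL term
        z_total += lse_square_scale * lse * lse  # z-loss penalty
    n = len(target)
    loss = (ce_total + z_total) / n
    if return_z_loss:
        return loss, z_total / n  # expose the auxiliary term for logging
    return loss
```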

ByronHsu commented 3 weeks ago

This is supported by https://github.com/linkedin/Liger-Kernel/pull/153