Closed ByronHsu closed 3 weeks ago
It would be great to have z-loss related functionality (i.e. lse_square_scale and return_z_loss) in cross entropy.
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py
This is supported by https://github.com/linkedin/Liger-Kernel/pull/153
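For reference, a minimal pure-Python sketch of what the two parameters would do, assuming the flash-attention semantics: lse_square_scale adds a penalty proportional to logsumexp(logits)^2 on top of the cross-entropy loss, and return_z_loss controls whether that penalty is returned separately. The function name and signature here are hypothetical, not the actual Liger-Kernel API:

```python
import math

def cross_entropy_with_z_loss(logits, target, lse_square_scale=0.0, return_z_loss=False):
    # Hypothetical reference sketch, not the fused Triton kernel.
    # z-loss = lse_square_scale * logsumexp(logits)^2 regularizes the
    # softmax normalizer, keeping the logits from drifting to large magnitudes.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))  # stable logsumexp
    ce = lse - logits[target]  # cross entropy: -log softmax(logits)[target]
    z_loss = lse_square_scale * lse * lse
    loss = ce + z_loss
    return (loss, z_loss) if return_z_loss else loss
```

With lse_square_scale=0.0 this reduces to plain cross entropy, which is why the feature can be added without changing existing behavior.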
🚀 The feature, motivation and pitch
Currently we hardcode the reduction to "mean", which means we don't have feature parity with https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html.
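For parity, the reduction would need to follow the torch.nn.CrossEntropyLoss semantics of "mean", "sum", and "none". A minimal sketch of that dispatch over per-token losses (helper name is illustrative, not part of the codebase):

```python
def reduce_losses(per_token_losses, reduction="mean"):
    # Mirrors torch.nn.CrossEntropyLoss reduction semantics:
    # "mean" averages, "sum" totals, "none" returns per-token values.
    if reduction == "mean":
        return sum(per_token_losses) / len(per_token_losses)
    if reduction == "sum":
        return sum(per_token_losses)
    if reduction == "none":
        return per_token_losses
    raise ValueError(f"unknown reduction: {reduction!r}")
```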
Note: Label smoothing is also missing. It is tracked in another issue: https://github.com/linkedin/Liger-Kernel/issues/81
Alternatives
No response
Additional context
No response