linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
BSD 2-Clause "Simplified" License

Support Z Loss in CE #239

Open Tcc0403 opened 1 week ago

Tcc0403 commented 1 week ago

Summary

This PR aims to resolve #197

Implemented z loss in LigerCrossEntropy.

Note: lse_square_scale is not exposed in flce yet; there are still issues passing the tests there.

Details

For loss:

\begin{align}
L_{total} &= L_{ce} + z\_loss \\
z\_loss &= lse\_square\_scale \cdot lse^2 \\
lse &= \log \sum e^{X_i}
\end{align}
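As a quick illustration of this definition, here is a minimal PyTorch sketch (not the PR's Triton kernel; the function name, the example scale 1e-4, and the mean reduction are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, lse_square_scale=1e-4):
    # Standard cross entropy plus the z-loss term, where lse is the
    # per-row log-sum-exp of the logits. Mean reduction is assumed here.
    ce = F.cross_entropy(logits, target)
    lse = torch.logsumexp(logits, dim=-1)          # lse = log sum e^{X_i}
    z_loss = lse_square_scale * (lse ** 2).mean()  # z_loss = lse_square_scale * lse^2
    return ce + z_loss

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
print(ce_with_z_loss(logits, target))
```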

We can use $m = max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from online softmax algorithm, to calculate $lse$ directly.

\begin{align}
lse &= \log \sum e^{X_i} \\
    &= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m \\
    &= \log e^m \sum e^{X_i - m} = m + \log d
\end{align}
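As a sanity check (a standalone sketch, not from the PR), the lse computed from the running max m and denominator d matches torch.logsumexp:

```python
import torch

X = torch.randn(8, 32000)
m = X.max(dim=-1).values                # running max from the online softmax
d = torch.exp(X - m[:, None]).sum(-1)   # sum of exp(X_i - m)
lse = m + torch.log(d)                  # lse = m + log d
print(torch.allclose(lse, torch.logsumexp(X, dim=-1), atol=1e-5))
```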

For gradients:

First, we calculate the derivative of lse

\begin{align}
\frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum e^{x_i}\right) \\
&= \frac{1}{\sum e^{x_i}} \cdot \frac{\partial}{\partial x_i} \sum e^{x_i} \\
&= \frac{e^{x_i}}{\sum e^{x_i}} = softmax(x_i).
\end{align}
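This identity can be spot-checked with autograd; a small standalone sketch (not part of the PR):

```python
import torch

x = torch.randn(16, requires_grad=True)
lse = torch.logsumexp(x, dim=0)
lse.backward()
# The gradient of logsumexp w.r.t. the logits is softmax(x).
print(torch.allclose(x.grad, torch.softmax(x, dim=0), atol=1e-6))
```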

Then we can obtain the derivative of z_loss by the chain rule:

\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right)  = 2\cdot lse\_square\_scale \cdot lse \cdot  softmax(x_i),

and we have the derivative of the cross-entropy loss with label smoothing:

\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K} = \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}

where $\epsilon$ is label_smoothing and $K$ is the total number of classes. Thus, the derivative of the total loss is

\begin{align}
\frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
&= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
&= \begin{cases} (1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
(1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
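The closed-form gradient can be checked against autograd; a standalone sketch under the definitions above (closed_form_grad and the example values eps=0.1, lse_square_scale=1e-4 are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

def closed_form_grad(x, y, eps, lse_square_scale):
    # Gradient of L_total w.r.t. the logits x for a single sample, per the derivation above.
    K = x.numel()
    p = torch.softmax(x, dim=0)
    lse = torch.logsumexp(x, dim=0)
    grad = (1 + 2 * lse_square_scale * lse) * p - eps / K
    grad[y] -= (1 - eps)
    return grad

x = torch.randn(10, requires_grad=True)
y = torch.tensor(3)
eps, scale = 0.1, 1e-4

# Reference loss: label-smoothed cross entropy plus the z-loss term.
loss = F.cross_entropy(x[None], y[None], label_smoothing=eps) \
       + scale * torch.logsumexp(x, dim=0) ** 2
loss.backward()
print(torch.allclose(x.grad, closed_form_grad(x.detach(), y, eps, scale), atol=1e-5))
```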

Reference

PaLM: Scaling Language Modeling with Pathways
Chameleon: Mixed-Modal Early-Fusion Foundation Models

Testing Done

Benchmark gist: negligible difference in the speed benchmark.

This benchmark was run on my machine, so the numbers may not be accurate.

liger ce: 66.123ms
Peak mem:  8.66200832

liger ce with zloss: 65.991ms
Peak mem:  8.66200832

liger ce with zloss with return zloss: 65.951ms
Peak mem:  8.662073856

Tcc0403 commented 1 week ago

Passed all tests. Ready for review!

Tcc0403 commented 4 days ago

Ignore the OOM errors. The current custom CrossEntropyWithZLoss (a torch.nn.Module), used as the ground-truth implementation, has precision issues in its gradient calculations with bfloat16 and reduction="sum".

The LigerCrossEntropyLoss in this PR passes the tests without issue when compared against flash-attn's CrossEntropyLoss. (gist)

The current goal is to bring the custom torch implementation on par with flash-attn's.

Update: problems solved

Tcc0403 commented 3 days ago

All passed