Tcc0403 opened 1 week ago
Passed all tests. Ready for review!
Ignore the OOM errors: the current custom CrossEntropyWithZLoss (a torch.nn.Module), used as the ground-truth implementation, has precision issues in its gradient calculation with bfloat16 and reduction="sum".
LigerCrossEntropyLoss in this PR passes the tests with no issue when compared against flash-attn's CrossEntropyLoss. (gist)
The current goal is to bring the custom torch implementation on par with flash-attn's.
Update: problems solved.
All tests passed.
Summary
This PR aims to resolve #197 by implementing z loss in LigerCrossEntropy.
Note: `lse_square_scale` is not exposed at the FLCE level yet, due to issues passing the tests.

Details
For loss:
We can use $m = \max(X_i)$ and $d = \sum_i e^{X_i - m}$, obtained from the online softmax algorithm, to calculate $lse$ directly:

$$lse = \log\sum_i e^{X_i} = m + \log d$$

so that $z\\_loss = \lambda \cdot lse^2$, where $\lambda$ is `lse_square_scale`.
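As a standalone sketch of this idea (pure Python, not the PR's actual Triton kernel; the function name `lse_online` is illustrative), the online-softmax pass that tracks $m$ and $d$ and then recovers $lse = m + \log d$ can be written as:

```python
import math

def lse_online(x):
    """Compute log-sum-exp in a single online-softmax-style pass.

    Tracks the running max m and the running sum d = sum(exp(x_i - m)),
    rescaling d whenever a new max is encountered, then returns
    lse = m + log(d).
    """
    m = float("-inf")
    d = 0.0
    for xi in x:
        if xi > m:
            # Rescale the accumulated sum to the new max, then add exp(0) = 1.
            d = d * math.exp(m - xi) + 1.0
            m = xi
        else:
            d += math.exp(xi - m)
    return m + math.log(d)
```

Because every exponent is shifted by the running max, no term overflows even for large logits, which is why $lse$ can be read off the online-softmax statistics at no extra cost.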
For gradients:
First, we calculate the derivative of $lse$:

$$\frac{\partial lse}{\partial X_i} = \frac{e^{X_i}}{\sum_j e^{X_j}} = \mathrm{softmax}(X_i)$$

Then we can obtain the derivative of z_loss by the chain rule:

$$\frac{\partial z\\_loss}{\partial X_i} = 2\lambda \cdot lse \cdot \mathrm{softmax}(X_i)$$

and we have the derivative of cross entropy loss with label smoothing:

$$\frac{\partial CE}{\partial X_i} = \mathrm{softmax}(X_i) - (1-\epsilon)\,\delta_{i,y} - \frac{\epsilon}{K}$$

where $\epsilon$ is `label_smoothing`, $K$ is the number of total classes, and $\lambda$ is `lse_square_scale`. Thus, the derivative of the total loss is

$$\frac{\partial L}{\partial X_i} = (1 + 2\lambda \cdot lse)\,\mathrm{softmax}(X_i) - (1-\epsilon)\,\delta_{i,y} - \frac{\epsilon}{K}$$
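As a sanity check on the combined gradient, here is a hedged pure-Python sketch (the helper names `total_loss` and `total_grad` are illustrative, not from the PR) that compares the analytic derivative above against central finite differences:

```python
import math

def total_loss(x, y, eps, lam):
    """Cross entropy with label smoothing eps plus z_loss = lam * lse^2."""
    K = len(x)
    lse = math.log(sum(math.exp(v) for v in x))
    # Smoothed target: (1 - eps) on the true class y, eps / K everywhere.
    ce = -sum(((1 - eps) * (i == y) + eps / K) * (x[i] - lse) for i in range(K))
    return ce + lam * lse * lse

def total_grad(x, y, eps, lam):
    """Analytic gradient: (1 + 2*lam*lse) * softmax(x_i) - (1-eps)*1{i=y} - eps/K."""
    K = len(x)
    lse = math.log(sum(math.exp(v) for v in x))
    return [(1 + 2 * lam * lse) * math.exp(x[i] - lse)
            - (1 - eps) * (i == y) - eps / K
            for i in range(K)]
```

A finite-difference comparison on a few random logits confirms that the closed-form gradient matches the numerical derivative of the total loss.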
Reference
- PaLM: Scaling Language Modeling with Pathways
- Chameleon: Mixed-Modal Early-Fusion Foundation Models
Testing Done
benchmark gist: negligible difference in the speed benchmark.
This benchmark was run on my machine, so the numbers may not be fully accurate.
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence