linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

[feat] FusedLinearCrossEntropy support for Gemma2 #127

Closed yundai424 closed 3 weeks ago

yundai424 commented 3 months ago

🚀 The feature, motivation and pitch

FusedLinearCrossEntropy (FLCE) needs special handling for the final-logit soft capping in Gemma2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
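For reference, the soft capping in the linked code scales the logits down by a cap, squashes them with tanh, and scales them back up, so every logit ends up in (-cap, cap) with its sign preserved. Below is a minimal pure-Python sketch of that forward computation, assuming a cap of 30.0 (the default I believe Gemma2 uses for `final_logit_softcapping`; check the model config). The helper name `softcap` is hypothetical and lists stand in for tensors.

```python
import math


def softcap(logits, cap=30.0):
    """Hypothetical helper: soft-cap each logit as cap * tanh(z / cap).

    The output lies strictly inside (-cap, cap) and keeps the sign of the input.
    """
    return [cap * math.tanh(z / cap) for z in logits]


raw = [-100.0, -1.0, 0.0, 1.0, 100.0]
capped = softcap(raw)
# Small logits pass through almost unchanged; large ones saturate near +/-cap.
print([round(v, 3) for v in capped])
```

Because tanh is roughly linear near zero, the cap only noticeably changes logits whose magnitude approaches the cap.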

Alternatives

No response

Additional context

No response

troy1729 commented 3 months ago

take

@yundai424 I would like to attempt to implement this.

I'm thinking this approach:

Can you assign it to me if this sounds okay?

qingquansong commented 3 months ago

@troy1729 Sounds reasonable to me. Assigned; feel free to kick off the implementation and ping us to discuss or review any issues. Thank you!

troy1729 commented 3 months ago

Hi @qingquansong, I've made the changes but still have to add the tests, so I've kept the PR in draft. This might be a silly question, but we would want a Triton kernel implementation for tanh (or any other non-linearity), wouldn't we? Right now I've added a torch.tanh callable. Sorry if this is obvious, but I thought I'd clarify.

qingquansong commented 3 months ago

Hey @troy1729, thanks for the question (there are no silly questions) and the fast kick-off! 1) Having certain Triton functions that operate on a single element/block is useful in some cases, such as the silu function we have for swiglu, which can be fused with other operations. In the end, we want to reduce the overhead of element-wise operations (like geglu/swiglu, or ReLU + matmul, etc.) rather than call a single activation kernel on its own, which would be no better than calling torch.tanh, especially after torch.compile. Also, see point 3 below: a standalone activation kernel would not help much when fusing it with other operations, especially in the backward pass (isolated forward/backward functions could be helpful, though).

2) The soft cap idea is mainly scaling plus capping the range: tanh squashes the scaled logits into (-1, 1) while keeping both positive and negative information, so some other torch activations may not be suitable here (though I agree we may have cases in the future that call extra torch activation functions).

3) You may want to think about how the backprop is computed given this activation on the logits. It's not as straightforward as just adding the activation: you'll need to compute the grad_input (the gradient of the hidden states) and the grad_weights.
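To make point 3 concrete, here is one way the chain rule could be worked through, as a pure-Python sketch rather than the Liger implementation. It assumes hidden states `h` of shape [tokens, hidden], a weight `w` of shape [vocab, hidden], raw logits z = h @ w^T and capped logits s = cap * tanh(z / cap). Since ds/dz = 1 - tanh(z / cap)^2, the upstream gradient is multiplied element-wise by that factor before the usual linear-layer gradients. All function names are hypothetical, and there is no chunking or Triton here.

```python
import math


def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def softcap_forward(h, w, cap):
    """Hypothetical helper: raw logits z = h @ w^T and capped logits s = cap * tanh(z / cap)."""
    z = matmul(h, transpose(w))
    s = [[cap * math.tanh(v / cap) for v in row] for row in z]
    return z, s


def softcap_backward(h, w, z, grad_s, cap):
    """Hypothetical helper: chain rule through the soft cap, then through the linear layer.

    grad_z = grad_s * (1 - tanh(z / cap)^2)   (element-wise)
    grad_h = grad_z @ w                        (grad_input, shape [tokens, hidden])
    grad_w = grad_z^T @ h                      (grad_weight, shape [vocab, hidden])
    """
    grad_z = [
        [g * (1.0 - math.tanh(v / cap) ** 2) for g, v in zip(g_row, z_row)]
        for g_row, z_row in zip(grad_s, z)
    ]
    grad_h = matmul(grad_z, w)
    grad_w = matmul(transpose(grad_z), h)
    return grad_h, grad_w


# Usage: 2 tokens, hidden size 2, vocab size 3, with an arbitrary upstream gradient dL/ds.
h = [[0.5, -1.0], [2.0, 0.3]]
w = [[1.0, 2.0], [-0.5, 0.7], [3.0, -1.5]]
z, s = softcap_forward(h, w, cap=2.0)
grad_s = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]
grad_h, grad_w = softcap_backward(h, w, z, grad_s, cap=2.0)
```

The only new ingredient compared with an uncapped linear layer is the element-wise `1 - tanh^2` factor. Everything after `grad_z` is the same matmul pair that computes grad_input and grad_weight without soft capping.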

4) Another option is to put the soft capping inside the regular Liger CE loss (you'd also need to handle the backward when this option is enabled). Then, in the FLCE kernel, you don't need to worry about the backprop outside the chunked kernel calls.
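A rough sketch of point 4 for a single row of logits, assuming the cross entropy itself applies the cap and returns the gradient with respect to the *raw* logits z. With p = softmax(s) and s = cap * tanh(z / cap), dL/dz_j = (p_j - 1[j == target]) * (1 - tanh(z_j / cap)^2). The function name is hypothetical, and a real Triton kernel would process rows in blocks and may organize this differently.

```python
import math


def softcapped_cross_entropy(z, target, cap):
    """Hypothetical helper: CE loss on soft-capped logits for one row, plus dL/dz."""
    t = [math.tanh(v / cap) for v in z]
    s = [cap * ti for ti in t]  # capped logits
    m = max(s)  # subtract the max for a numerically stable softmax
    exps = [math.exp(v - m) for v in s]
    denom = sum(exps)
    probs = [e / denom for e in exps]
    loss = -(s[target] - m - math.log(denom))  # -log softmax(s)[target]
    # dL/ds_j = p_j - 1[j == target]; chain rule through the cap: ds_j/dz_j = 1 - t_j^2
    grad_z = [
        (p - (1.0 if j == target else 0.0)) * (1.0 - tj * tj)
        for j, (p, tj) in enumerate(zip(probs, t))
    ]
    return loss, grad_z


loss, grad_z = softcapped_cross_entropy([1.5, -0.7, 3.2, 0.1], target=2, cap=2.5)
```

Since the returned gradient is already taken with respect to the uncapped logits, the FLCE chunk loop could feed it into the same grad_input/grad_weight matmuls it uses without soft capping.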

In sum, my suggestion would be: implement only the tanh option for now, and follow the geglu backward to see how the tanh gradient is computed with the chain rule, derive the equation, and implement it here.

Tcc0403 commented 1 month ago

I believe I've correctly implemented softcap in the cross entropy function, as well as FLCE support for Gemma2. But since Gemma2 currently can't pass the test even without FLCE, do I need to find a way to make it pass the relevant convergence test (test_mini_models_no_logits.py)? cc @yundai424