unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Introduce MsT technologies into unsloth to extend sequence length #1082

Open wdlctc opened 1 month ago

wdlctc commented 1 month ago

Description

This pull request introduces optimizations to the LLaMA model implementation, specifically targeting the language modeling head and forward pass. The main changes include:

- Implement a custom `_LM_head` using `torch.autograd.Function` for more efficient forward and backward passes.
- Introduce a `LMheadWarpper` class to manage the custom LM head.
- Add a `minis_processing` function to handle mini-batch processing of hidden states and labels.
- Modify the `CausalLM_fast_forward` function to use the new mini-batch processing and custom LM head.
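For reference, here is a minimal, self-contained sketch of the pattern being described; the class name and chunk size below are illustrative stand-ins for the PR's `_LM_head` / `minis_processing`, not the actual code:

```python
# Minimal sketch of a chunked LM head: the forward pass computes the loss one
# mini-batch of rows (token positions) at a time so the full [seq, vocab]
# logits tensor is never materialized, and the backward pass re-computes each
# chunk's logits on the fly. Names and the chunk size are illustrative.
import torch
import torch.nn.functional as F

class ChunkedLMHead(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden, weight, labels, chunk_size=1024):
        # hidden: [n_tokens, hidden_dim], weight: [vocab, hidden_dim], labels: [n_tokens]
        ctx.save_for_backward(hidden, weight, labels)
        ctx.chunk_size = chunk_size
        n_valid = (labels != -100).sum().clamp(min=1)
        total_loss = torch.zeros((), dtype=torch.float32, device=hidden.device)
        for i in range(0, hidden.shape[0], chunk_size):
            logits = hidden[i:i + chunk_size] @ weight.t()   # freed after this iteration
            total_loss += F.cross_entropy(
                logits.float(), labels[i:i + chunk_size],
                ignore_index=-100, reduction="sum")
        return total_loss / n_valid

    @staticmethod
    def backward(ctx, grad_output):
        hidden, weight, labels = ctx.saved_tensors
        chunk_size = ctx.chunk_size
        n_valid = (labels != -100).sum().clamp(min=1)
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        for i in range(0, hidden.shape[0], chunk_size):
            h, lab = hidden[i:i + chunk_size], labels[i:i + chunk_size]
            probs = (h @ weight.t()).float().softmax(dim=-1)  # re-compute this chunk's logits
            valid = lab != -100
            grad_logits = probs
            grad_logits[valid, lab[valid]] -= 1.0             # d(CE)/d(logits) = softmax - one_hot
            grad_logits[~valid] = 0.0
            grad_logits = (grad_logits * grad_output / n_valid).to(h.dtype)
            grad_hidden[i:i + chunk_size] = grad_logits @ weight
            grad_weight += grad_logits.t() @ h
        return grad_hidden, grad_weight, None, None
```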

Changes

Benefits

Testing

Please ensure to test this implementation thoroughly, especially:

- Performance comparison with the original implementation
- Correctness of loss calculation and gradient computation
- Memory usage across various input sizes
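One way to cover the correctness points is to compare the chunked path against a plain full-logits reference on random data; a hedged sketch, reusing the illustrative `ChunkedLMHead` above rather than the PR's exact API:

```python
# Quick correctness check: the chunked path should reproduce the loss and
# gradients of a plain full-logits cross-entropy (illustrative shapes, fp32, CPU).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, hidden_dim, vocab_size = 512, 256, 1000
hidden = torch.randn(seq_len, hidden_dim, requires_grad=True)
weight = torch.randn(vocab_size, hidden_dim, requires_grad=True)
labels = torch.randint(0, vocab_size, (seq_len,))

# Reference: materialize the full [seq, vocab] logits.
ref_loss = F.cross_entropy(hidden @ weight.t(), labels)
ref_loss.backward()
ref_grad_h, ref_grad_w = hidden.grad.clone(), weight.grad.clone()
hidden.grad = weight.grad = None

# Chunked path (the illustrative ChunkedLMHead sketched earlier).
chunk_loss = ChunkedLMHead.apply(hidden, weight, labels, 128)
chunk_loss.backward()

assert torch.allclose(ref_loss, chunk_loss, rtol=1e-4, atol=1e-4)
assert torch.allclose(ref_grad_h, hidden.grad, rtol=1e-4, atol=1e-4)
assert torch.allclose(ref_grad_w, weight.grad, rtol=1e-4, atol=1e-4)
```

For the memory point, peak usage of the two paths on GPU can be compared with `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()` around each run.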

shimmyshimmer commented 1 month ago

Thank you @wdlctc ! We will review it and hopefully be able to push it in after our multimodal release! :)

wdlctc commented 1 month ago

Thank you @shimmyshimmer for your review. I've added detailed training info for reference:

  1. Standard training with slightly better loss: unsloth (1.192900) vs unsloth+MST (1.165600)
  2. 2x longer sequence length on LLAMA2: unsloth OOMs at 25k and works at 12k, while unsloth+MST works at 25k

For more implementation details, you can refer to our blog: https://wdlctc.github.io/mst.html or our paper: https://www.arxiv.org/abs/2407.15892

If you need other fine-tuning settings, I can try them another time.

wdlctc commented 1 month ago

Rewrote it with unsloth's fast_cross_entropy. We were surprised to find that integrating MST with unsloth not only improves memory behavior, but also introduces a speedup.

The key difference: checkpointing the hidden_state of the LM head (its input) instead of checkpointing the logits (its output).
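A back-of-the-envelope comparison of the two checkpointing choices, with assumed LLaMA-3-8B-like shapes and bf16 activations (illustrative arithmetic, not measurements from this PR):

```python
# Rough size of what gets saved for the LM-head backward (assumed LLaMA-3-8B-like
# shapes, bf16 activations; purely illustrative arithmetic).
seq_len, hidden_dim, vocab_size, bytes_per_el = 25_000, 4_096, 128_256, 2

logits_ckpt = seq_len * vocab_size * bytes_per_el   # checkpoint the output (logits)
hidden_ckpt = seq_len * hidden_dim * bytes_per_el   # checkpoint the input (hidden_state)

print(f"logits: {logits_ckpt / 2**30:.2f} GiB, "
      f"hidden: {hidden_ckpt / 2**30:.2f} GiB, "
      f"ratio ~{vocab_size / hidden_dim:.0f}x")
# -> logits: 5.97 GiB, hidden: 0.19 GiB, ratio ~31x
```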

danielhanchen commented 1 month ago

@wdlctc Thanks a lot again!! I'll test it and verify all losses match! Appreciate it!

wdlctc commented 1 month ago

10/14/2024: Resolved the conflicts with the nightly branch.

danielhanchen commented 1 month ago

Sorry on the delay - was planning to add this together with Vision support :) It might take a few more days!

danielhanchen commented 1 month ago

Oh lol I noticed I accidentally deleted this PR after I deleted the nightly branch - whoops so sorry!

danielhanchen commented 1 month ago

Interesting - so I looked through the paper and code. Essentially you're proposing to do gradient accumulation inside of each sequence? I.e. the usual approach chunks the CE loss kernel amongst large columnar blocks, but you're suggesting the 2nd option - chunking the rows themselves.

And the trick is since it's row chunked, we also do not materialize the full logits but instead re-compute them on the fly?
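Put differently, the per-token losses decompose over row chunks exactly like micro-batches under gradient accumulation; in assumed notation (LM-head weight $W$, hidden states $h_t$, labels $y_t$, $N$ valid tokens split into chunks $c = 1, \dots, C$):

$$
\mathcal{L} = \frac{1}{N} \sum_{c=1}^{C} \sum_{t \in \mathrm{chunk}_c} \ell\!\left(W h_t,\, y_t\right),
\qquad
\nabla_W \mathcal{L} = \frac{1}{N} \sum_{c=1}^{C} \sum_{t \in \mathrm{chunk}_c} \nabla_W\, \ell\!\left(W h_t,\, y_t\right),
$$

so each chunk's logits $W h_t$ can be formed, used, and discarded (or re-computed in the backward pass) without ever holding the full seq × vocab matrix.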

wdlctc commented 1 month ago

Yes! The key insight is that the full logits are too big, especially when the vocabulary size is large, as on LLAMA3 (128k) and Gemma2 (256k), so re-computing them on the fly can effectively reduce memory (only compute one chunk at a time and discard the previous one) and time (for offloading).

We do suggest the row chunking, but you can also do both, row and col, since for the LM head and MLP the row and col (batch and seq) are independent. And it is effective, as long-context training would use local_batch_size=1.
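A small illustration of that independence, with assumed shapes and `local_batch_size=1` as in the long-context setting described above:

```python
# The LM head scores every (batch, position) token independently, so the batch
# and sequence dims can be flattened and chunked together (shapes illustrative).
import torch

local_batch_size, seq_len, hidden_dim, vocab_size = 1, 25_000, 4_096, 128_256
hidden = torch.randn(local_batch_size, seq_len, hidden_dim, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (local_batch_size, seq_len))

flat_hidden = hidden.view(-1, hidden_dim)   # [batch*seq, hidden_dim]
flat_labels = labels.view(-1)               # row-chunk this single flattened dimension
```

With local_batch_size=1 the flattened dimension is just the sequence, so row chunking over it is exactly chunking the sequence.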
