wdlctc opened this pull request 1 month ago
Thank you @wdlctc ! We will review it and hopefully be able to push it in after our multimodal release! :)
Thank you @shimmyshimmer! For your review I've added detailed training info for reference:
For more implementation details, you can refer to our blog: https://wdlctc.github.io/mst.html or our paper: https://www.arxiv.org/abs/2407.15892
If you need other fine-tuning settings, I can try them another time.
Rewrote it with unsloth fast_cross_entropy. We were surprised to find that integrating MST with unsloth not only improves memory behavior but also introduces a speedup.
The key difference: checkpointing the hidden_state input of the LM head instead of checkpointing the logits output.
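To illustrate the difference (a minimal sketch with made-up shapes, using `torch.utils.checkpoint` rather than the PR's custom autograd function): only the `[batch, seq, hidden]` input is kept for backward, while the `[batch, seq, vocab]` logits are discarded after the forward pass and recomputed on the fly.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy shapes for illustration; real vocabularies are much larger (e.g. 128k for Llama 3).
batch, seq, d, vocab = 1, 1024, 512, 32000
hidden = torch.randn(batch, seq, d, requires_grad=True)
weight = torch.randn(vocab, d, requires_grad=True)   # LM-head weight
labels = torch.randint(0, vocab, (batch, seq))

def lm_head_loss(hidden, weight, labels):
    # Logits exist only inside this function; with checkpointing they are
    # freed after the forward pass and recomputed during backward.
    logits = hidden @ weight.t()                      # [batch, seq, vocab]
    return torch.nn.functional.cross_entropy(
        logits.view(-1, vocab).float(), labels.view(-1))

# Checkpoint the LM-head *input* (hidden, [batch, seq, d]) rather than its
# *output* (logits, [batch, seq, vocab]); autograd keeps only the small tensor.
loss = checkpoint(lm_head_loss, hidden, weight, labels, use_reentrant=False)
loss.backward()
```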
@wdlctc Thanks a lot again!! I'll test it and verify all losses match! Appreciate it!
10/14/2024: Resolved the conflicts with the nightly branch.
Sorry for the delay - was planning to add this together with Vision support :) It might take a few more days!
Oh lol I noticed I accidentally deleted this PR after I deleted the nightly branch - whoops so sorry!
Interesting - so I looked through the paper and code; essentially you're proposing to do gradient accumulation inside each sequence? I.e. the first is normally chunking the CE loss kernel amongst large columnar blocks, but you're suggesting the second - chunking the rows themselves.
And the trick is that, since it's row-chunked, we also do not materialize the full logits but instead re-compute them on the fly?
Yes! The key insight is that the full logits are too big, especially when the vocabulary size is large, as in Llama 3 (128k) and Gemma 2 (256k), so recomputing them on the fly can effectively reduce memory (only compute one chunk at a time and discard the previous chunk) and time (for offloading).
We do suggest row chunking, but you can also do both row and column, since for the LM head and MLP the row and column dimensions (batch and sequence) are independent. And it is effective because long-context training typically uses local_batch_size=1.
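To make that concrete, here is a hedged sketch of row-chunked cross-entropy (hypothetical helper names, not the PR's implementation): the token dimension is split into mini-sequences, each chunk is checkpointed, and each chunk's logits are recomputed during backward, so only one `[chunk, vocab]` block is alive at a time.

```python
import torch
from torch.utils.checkpoint import checkpoint

def _chunk_loss_sum(h, weight, y):
    # Forward for one row chunk; its logits are recomputed during backward.
    logits = h @ weight.t()                           # [chunk, vocab]
    return torch.nn.functional.cross_entropy(
        logits.float(), y, ignore_index=-100, reduction="sum")

def mini_sequence_ce(hidden, weight, labels, chunk_size=2048):
    """Row-chunked cross-entropy: hidden states and labels are split along the
    token (batch * seq) dimension, and each chunk's logits are computed, used
    and discarded, instead of materializing the full [tokens, vocab] matrix."""
    hidden = hidden.reshape(-1, hidden.shape[-1])     # [tokens, d]
    labels = labels.reshape(-1)                       # [tokens]
    n_valid = (labels != -100).sum().clamp(min=1)
    total = hidden.new_zeros(())
    for start in range(0, hidden.shape[0], chunk_size):
        total = total + checkpoint(
            _chunk_loss_sum,
            hidden[start:start + chunk_size],
            weight,
            labels[start:start + chunk_size],
            use_reentrant=False)
    return total / n_valid                            # mean over non-ignored tokens
```

In effect this is gradient accumulation over mini-sequences within a single forward/backward pass, and it composes with column (vocabulary) chunking since the two dimensions are independent.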
Description
This pull request introduces optimizations to the LLaMA model implementation, specifically targeting the language modeling head and forward pass. The main changes include:

- Implement a custom `_LM_head` using `torch.autograd.Function` for more efficient forward and backward passes (see the sketch after this list).
- Introduce a `LMheadWarpper` class to manage the custom LM head.
- Add a `minis_processing` function to handle mini-batch processing of hidden states and labels.
- Modify the `CausalLM_fast_forward` function to use the new mini-batch processing and custom LM head.
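As a rough illustration of this structure (the class below is a hypothetical sketch written for this description, not the PR's actual `_LM_head`), a custom `torch.autograd.Function` can save the hidden states, weight and labels in forward and recompute the logits chunk by chunk in backward:

```python
import torch

class _ChunkedLMHeadCE(torch.autograd.Function):
    """Hypothetical sketch: LM head + cross-entropy that saves the hidden
    states (input) for backward and recomputes logits chunk by chunk, rather
    than keeping the full [tokens, vocab] logits alive."""

    @staticmethod
    def forward(ctx, hidden, weight, labels, chunk_size):
        # hidden: [tokens, d], weight: [vocab, d], labels: [tokens]
        n_valid = (labels != -100).sum().clamp(min=1)
        total = torch.zeros((), device=hidden.device, dtype=torch.float32)
        for s in range(0, hidden.shape[0], chunk_size):
            logits = hidden[s:s + chunk_size] @ weight.t()     # one chunk only
            total += torch.nn.functional.cross_entropy(
                logits.float(), labels[s:s + chunk_size],
                ignore_index=-100, reduction="sum")
        ctx.save_for_backward(hidden, weight, labels)          # no logits saved
        ctx.chunk_size, ctx.n_valid = chunk_size, n_valid
        return total / n_valid

    @staticmethod
    def backward(ctx, grad_out):
        hidden, weight, labels = ctx.saved_tensors
        grad_hidden, grad_weight = torch.zeros_like(hidden), torch.zeros_like(weight)
        for s in range(0, hidden.shape[0], ctx.chunk_size):
            h, y = hidden[s:s + ctx.chunk_size], labels[s:s + ctx.chunk_size]
            logits = h @ weight.t()                            # recomputed on the fly
            probs = torch.softmax(logits.float(), dim=-1)
            one_hot = torch.nn.functional.one_hot(y.clamp(min=0), weight.shape[0])
            valid = (y != -100).unsqueeze(1)
            dlogits = ((probs - one_hot) * valid / ctx.n_valid * grad_out).to(h.dtype)
            grad_hidden[s:s + ctx.chunk_size] = dlogits @ weight
            grad_weight += dlogits.t() @ h
        return grad_hidden, grad_weight, None, None

# Usage sketch: loss = _ChunkedLMHeadCE.apply(hidden_2d, lm_head_weight, labels_1d, 4096)
```

A `minis_processing`-style helper would then only need to flatten and slice the hidden states and labels into the chunks consumed above.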
Changes
Benefits
Testing
Please test this implementation thoroughly, especially:
- Performance comparison with the original implementation
- Correctness of loss calculation and gradient computation (see the sketch below)
- Memory usage across various input sizes
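For the correctness part, one simple check (a hedged sketch reusing the hypothetical `mini_sequence_ce` helper from the earlier sketch) is to compare loss and gradients against a plain full-logits cross-entropy reference:

```python
import torch

torch.manual_seed(0)
tokens, d, vocab = 512, 64, 1000
hidden = torch.randn(tokens, d, requires_grad=True)
weight = torch.randn(vocab, d, requires_grad=True)
labels = torch.randint(0, vocab, (tokens,))

# Reference path: materialize the full [tokens, vocab] logits.
ref_loss = torch.nn.functional.cross_entropy(hidden @ weight.t(), labels)
ref_loss.backward()
ref_gh, ref_gw = hidden.grad.clone(), weight.grad.clone()
hidden.grad = weight.grad = None

# Chunked path (e.g. the mini_sequence_ce sketch above, or the PR's LM head).
chunked_loss = mini_sequence_ce(hidden, weight, labels, chunk_size=128)
chunked_loss.backward()

assert torch.allclose(ref_loss, chunked_loss, atol=1e-5)
assert torch.allclose(ref_gh, hidden.grad, atol=1e-5)
assert torch.allclose(ref_gw, weight.grad, atol=1e-5)
```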