whyNLP / LCKV

Layer-Condensed KV Cache: up to 10x larger batch sizes, fewer parameters, and less computation. Dramatic speedup with better task performance. Accepted to ACL 2024.
https://arxiv.org/abs/2405.10637

Question about gradient checkpointing? #3

Closed · 311dada closed this issue 5 months ago

311dada commented 5 months ago

Congratulations on the excellent work!

When training large language models, we generally use gradient checkpointing. Could you explain how to enable it in your code?

Thanks a lot!

why-in-Shanghaitech commented 5 months ago

Hi! Thank you for your question. I haven't used gradient checkpointing before, so I cannot guarantee the correctness of the following solution:

  1. Comment out Hugging Face's `use_cache` warning block, since we need the cache for further iterations (see the sketch after this list): https://github.com/whyNLP/LCKV/blob/503c82fcea86697a513814523ba5f80adaff390e/models/modeling_llama_opt.py#L1338-L1343

  2. Add `--gradient_checkpointing` to the bash script (replace this line or add it below): https://github.com/whyNLP/LCKV/blob/503c82fcea86697a513814523ba5f80adaff390e/run_clm.sh#L31
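
For step 1, a minimal sketch of what the edit might look like, assuming the linked lines match the stock Hugging Face LLaMA guard that turns off the KV cache under gradient checkpointing (the exact wording in `modeling_llama_opt.py` may differ):

```python
# Sketch only: the linked block is assumed to follow the standard
# Hugging Face LLaMA pattern below. Leaving it commented out keeps
# `use_cache=True` during training, which LCKV needs for its further
# iterations.
#
# if self.gradient_checkpointing and self.training:
#     if use_cache:
#         logger.warning_once(
#             "`use_cache=True` is incompatible with gradient checkpointing. "
#             "Setting `use_cache=False`..."
#         )
#         use_cache = False
```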

I tested the modified code on an RTX 3090 with the tinyllama_opt.json config and batch size 16 with gradient checkpointing. The training loss over the first 20 steps is consistent with the run without gradient checkpointing that uses batch size 4 and gradient accumulation 4.
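
For anyone repeating this sanity check: the two runs are comparable because they share the same effective batch size (4 x 4 = 16). A minimal sketch with Hugging Face `TrainingArguments` (the output paths are hypothetical and this is not the repo's exact run_clm.sh invocation):

```python
from transformers import TrainingArguments

# Run A: gradient checkpointing on, full batch in one forward pass.
args_checkpointed = TrainingArguments(
    output_dir="outputs/gc-bs16",        # hypothetical path
    per_device_train_batch_size=16,
    gradient_checkpointing=True,
)

# Run B: no checkpointing; same effective batch size via accumulation (4 x 4).
args_baseline = TrainingArguments(
    output_dir="outputs/nogc-bs4x4",     # hypothetical path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
)

# With identical data order and seeds, the training loss over the first
# ~20 steps should be close between the two runs.
```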

I hope this helps. If it works, I'd appreciate a small PR so that more people can benefit from the gradient checkpointing feature.

311dada commented 5 months ago

Sorry for my late response. I will try it. Thanks for your suggestion.