Closed: 311dada closed this issue 5 months ago
Hi! Thank you for your question. I haven't used gradient checkpointing before, so I cannot guarantee the correctness of this solution:
Disable Hugging Face's warning about `use_cache` by commenting it out, since we need the cache for the later iterations:
https://github.com/whyNLP/LCKV/blob/503c82fcea86697a513814523ba5f80adaff390e/models/modeling_llama_opt.py#L1338-L1343
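For reference, here is a minimal sketch of what the linked block likely looks like, assuming it matches the upstream `transformers` Llama implementation (the exact wording in `modeling_llama_opt.py` may differ):

```python
# In the model's forward pass, comment out the block that force-disables
# the KV cache under gradient checkpointing, since LCKV needs
# `use_cache=True` for the later iterations:

# if self.gradient_checkpointing and self.training:
#     if use_cache:
#         logger.warning_once(
#             "`use_cache=True` is incompatible with gradient checkpointing. "
#             "Setting `use_cache=False`..."
#         )
#         use_cache = False
```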
Add `--gradient_checkpointing` to the bash script (replace this line, or add the flag below it):
https://github.com/whyNLP/LCKV/blob/503c82fcea86697a513814523ba5f80adaff390e/run_clm.sh#L31
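Assuming the training script parses its arguments with Hugging Face's `HfArgumentParser` into `TrainingArguments`, as the standard `run_clm.py` example does, the flag maps onto `TrainingArguments.gradient_checkpointing`. A rough sketch of the equivalent programmatic setting (all values other than `gradient_checkpointing` are placeholders):

```python
from transformers import TrainingArguments

# Equivalent to passing --gradient_checkpointing on the command line;
# the Trainer will then call model.gradient_checkpointing_enable()
# before training starts.
training_args = TrainingArguments(
    output_dir="outputs",            # placeholder
    per_device_train_batch_size=16,  # placeholder
    gradient_checkpointing=True,
)
```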
I tested the modified code on an RTX 3090 with config `tinyllama_opt.json` and batch size 16 with gradient checkpointing. The training loss over the first 20 steps is consistent with the run without gradient checkpointing that uses batch size 4 and gradient accumulation 4 (both settings give an effective batch size of 16).
I hope this helps. If it works, I'd appreciate it if you could open a simple PR so that more people can benefit from the gradient checkpointing feature.
Sorry for my late response. I will try it. Thanks for your suggestion.
Congratulations on the excellent work!
When training large language models, we generally adopt the gradient checkpointing technique. Could you please help me enable this technique in your code?
Thanks a lot!