young-geng / EasyLM

Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0
2.33k stars, 247 forks

Feature request: Use Orbax for checkpointing. #80

Open OhadRubin opened 1 year ago

OhadRubin commented 1 year ago

What do you think about changing the current checkpointing to use Orbax? https://github.com/google/orbax

young-geng commented 1 year ago

I think the current streaming checkpointer is quite good, and I'm not planning to switch to Orbax. Specifically, the streaming checkpointer can save and load checkpoints with minimal memory and temporary disk usage, which means you can checkpoint models that do not fit in the memory or local disk of a single machine. Also, as a personal preference, I want to stay away from complicated Google libraries as much as possible.
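To make the memory argument concrete, here is a minimal sketch of a streaming checkpoint format, assuming the parameters are a nested dict of NumPy arrays. It is not EasyLM's actual implementation: the record layout (a 4-byte key length, the UTF-8 key, then one `.npy` payload per tensor) and the helper names `flatten`, `save_streaming`, `load_streaming` and `unflatten` are hypothetical. What it illustrates is that writing and reading one tensor at a time keeps peak host memory at roughly one tensor rather than the whole model, so a loader could shard or offload each tensor before reading the next.

```python
import io
import struct
from typing import Any, BinaryIO, Dict, Iterator, Tuple

import numpy as np


def flatten(tree: Dict[str, Any], prefix: str = "") -> Iterator[Tuple[str, np.ndarray]]:
    """Lazily yield (path, leaf) pairs depth-first; '/' is assumed to be a safe separator."""
    for name, value in tree.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, np.asarray(value)


def save_streaming(tree: Dict[str, Any], fout: BinaryIO) -> None:
    """Append one (key, tensor) record at a time (hypothetical layout, not EasyLM's)."""
    for path, array in flatten(tree):
        key = path.encode("utf-8")
        fout.write(struct.pack("<I", len(key)))  # 4-byte little-endian key length
        fout.write(key)
        np.save(fout, array)  # .npy header + raw data written at the current position


def load_streaming(fin: BinaryIO) -> Iterator[Tuple[str, np.ndarray]]:
    """Yield records one by one; the caller decides what to keep, shard, or discard."""
    while True:
        header = fin.read(4)
        if not header:  # clean end of stream
            return
        (key_len,) = struct.unpack("<I", header)
        path = fin.read(key_len).decode("utf-8")
        yield path, np.load(fin)  # reads exactly one .npy payload


def unflatten(pairs) -> Dict[str, Any]:
    """Rebuild the nested dict from (path, array) pairs."""
    tree: Dict[str, Any] = {}
    for path, array in pairs:
        *parents, leaf = path.split("/")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = array
    return tree


if __name__ == "__main__":
    params = {"embed": np.ones((4, 8), np.float32),
              "layer_0": {"w": np.arange(6).reshape(2, 3), "b": np.zeros(3)}}
    buf = io.BytesIO()  # stands in for a file or remote object stream
    save_streaming(params, buf)
    buf.seek(0)
    restored = unflatten(load_streaming(buf))
    print(restored["layer_0"]["w"])
```

Because `load_streaming` is a generator, a caller that places each tensor on devices as soon as it is read never needs the full checkpoint on the host at once. A real system would also want checksums and sharded gather/scatter on save and load; those are left out here.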

erfanzar commented 12 months ago

Orbax does not support streaming loads or sharding data/arrays across devices with pjit, so I think the current checkpointing method is a smart move :)