Open OhadRubin opened 1 year ago
I think the current streaming checkpointer is quite good and I'm not planning to switch to orbax. Specifically, the streaming checkpointer can save and load checkpoints with minimal memory and temporary disk usage, which means that you can checkpoint models that do not fit in the memory or local disk of a single machine. Also, as a personal preference, I want to stay away from complicated Google libraries as much as possible.
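To illustrate the streaming idea, here is a minimal sketch. It is not this repo's actual checkpointer: the record format and the names `save_streaming`, `load_streaming`, and `fake_params` are made up for the example, and it uses plain NumPy rather than sharded JAX arrays. The point is only that tensors are written and read one at a time, so peak memory is roughly one tensor instead of the whole model.

```python
import io
import struct

import numpy as np


def save_streaming(named_arrays, fileobj):
    """Write (name, array) pairs one at a time, so only one array is held at once."""
    for name, array in named_arrays:
        key = name.encode("utf-8")
        # Hypothetical record layout: 4-byte key length, key bytes, then a .npy payload.
        fileobj.write(struct.pack("<I", len(key)))
        fileobj.write(key)
        np.save(fileobj, np.asarray(array), allow_pickle=False)


def load_streaming(fileobj):
    """Yield (name, array) pairs one at a time from a stream written by save_streaming."""
    while True:
        header = fileobj.read(4)
        if not header:  # end of stream
            return
        (key_len,) = struct.unpack("<I", header)
        name = fileobj.read(key_len).decode("utf-8")
        yield name, np.load(fileobj, allow_pickle=False)


def fake_params():
    # Stand-in for a model's parameters, produced lazily.
    yield "dense/kernel", np.arange(6, dtype=np.float32).reshape(2, 3)
    yield "dense/bias", np.zeros(3, dtype=np.float32)


buffer = io.BytesIO()  # a real checkpoint would use a file or remote object instead
save_streaming(fake_params(), buffer)
buffer.seek(0)
restored = dict(load_streaming(buffer))
```

In a real setup the consumer would place each tensor on its target devices as it is yielded, rather than collecting everything into a dict as done here for the demo.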
orbax does not support streaming loads, or sharding data or arrays across devices with pjit, so I think the checkpointing method being used right now is a smart move :)
wdyt about changing the current checkpointing to use orbax? https://github.com/google/orbax