Hey! Great question. There are a few ways that you can go about saving models.
The simplest way is to note that NT params are made of standard Python data structures (tuples and lists) along with JAX arrays, which serialize to standard NumPy arrays. Thus, one option is to use pickle to save the whole params tree. Another is to flatten the tree, save the arrays using numpy.save or numpy.savez, and then save the tree structure using pickle.
For more details and sample code for this approach check out the thread over on Haiku: https://github.com/deepmind/dm-haiku/issues/18
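As a rough illustration of the flatten-and-save option, here is a minimal sketch. The helper names save_params and load_params, the file names arrays.npz and tree.pkl, and the toy params list are all made up for this example (they are not part of NT), and the structure is recorded by pickling a copy of the tree with every leaf replaced by 0, similar in spirit to the Haiku thread.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_params(params, ckpt_dir):
    """Save leaves with numpy.savez and the tree structure with pickle."""
    leaves = jax.tree_util.tree_leaves(params)
    # Positional arrays are stored under the keys arr_0, arr_1, ...
    np.savez(os.path.join(ckpt_dir, "arrays.npz"),
             *[np.asarray(leaf) for leaf in leaves])
    # A copy of the tree with every leaf replaced by 0 records the structure.
    structure = jax.tree_util.tree_map(lambda _: 0, params)
    with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
        pickle.dump(structure, f)


def load_params(ckpt_dir):
    """Rebuild the params tree from the two files written by save_params."""
    with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
        structure = pickle.load(f)
    treedef = jax.tree_util.tree_structure(structure)
    with np.load(os.path.join(ckpt_dir, "arrays.npz")) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"])
                  for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Toy stand-in for NT params: a list of tuples, with an empty tuple for a
# parameter-free layer (hypothetical shapes, for illustration only).
params = [
    (jnp.ones((3, 4)), jnp.zeros((4,))),
    (),
    (jnp.ones((4, 1)), jnp.zeros((1,))),
]

with tempfile.TemporaryDirectory() as ckpt_dir:
    save_params(params, ckpt_dir)
    restored = load_params(ckpt_dir)
```

After the round trip, restored should have the same nesting and the same array values as params, so it can be passed to apply_fn in place of the original.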
Another option that's a little more complicated is to use jax2tf to convert the model to TensorFlow and then save it as a SavedModel. This has the advantage that it's hermetic, so you don't need to keep the code that constructs the model intact.
See here for more details: https://github.com/google/jax/tree/main/jax/experimental/jax2tf
In general, I would probably save the model as NumPy arrays during training and then, if I wanted a longer-term storage option for using the model on downstream tasks, look into the SavedModel pipeline.