google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question about training model #150

Closed Firegreat123 closed 2 years ago

Firegreat123 commented 2 years ago
sschoenholz commented 2 years ago

Hey! Great question. There are a few ways that you can go about saving models.

The simplest way is to note that NT params are standard Python data structures (tuples and lists) containing JAX arrays, which serialize as standard NumPy arrays. One option is to pickle the whole params tree; another is to flatten the tree, save the leaves with numpy.save or numpy.savez, and then pickle the tree structure separately.

For more details and sample code for this approach, check out the thread over on Haiku: https://github.com/deepmind/dm-haiku/issues/18
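
Here's a minimal sketch of both variants, assuming the usual NT setup where `params` is a pytree of tuples/lists of JAX arrays (the toy model shapes and file names below are just for illustration):

```python
import pickle

import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for params returned by an nt.stax init_fn:
# nested tuples/lists of JAX arrays.
params = ((jnp.ones((3, 4)), jnp.zeros((4,))), ())

# Option 1: pickle the whole params tree in one go.
with open('params.pkl', 'wb') as f:
    pickle.dump(params, f)

# Option 2: flatten the tree, save the leaves as NumPy arrays,
# and keep the tree structure separately.
leaves, treedef = jax.tree_util.tree_flatten(params)
np.savez('leaves.npz', *[np.asarray(x) for x in leaves])
with open('treedef.pkl', 'wb') as f:
    # Treedefs of built-in containers pickle in recent JAX; with older
    # versions you can instead re-derive the structure by flattening a
    # freshly re-initialized params tree.
    pickle.dump(treedef, f)

# Restore: np.savez names positional arrays arr_0, arr_1, ... in order.
with np.load('leaves.npz') as data:
    loaded_leaves = [data[f'arr_{i}'] for i in range(len(data.files))]
with open('treedef.pkl', 'rb') as f:
    loaded_treedef = pickle.load(f)
restored_params = jax.tree_util.tree_unflatten(loaded_treedef, loaded_leaves)
```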

Another option that's a little more involved is to use jax2tf to convert the model to TensorFlow and then export it as a SavedModel. This has the advantage of being hermetic: you don't need to keep the original model-construction code around in order to reload it.

See here for more details: https://github.com/google/jax/tree/main/jax/experimental/jax2tf
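
A rough sketch of that route, following the pattern in the jax2tf README; the layer sizes, input shape, and output path here are placeholders, and the params are freshly initialized rather than trained for brevity:

```python
import tensorflow as tf
from jax import random
from jax.experimental import jax2tf

from neural_tangents import stax

# Build a small NT model and get its params.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
_, params = init_fn(random.PRNGKey(0), (-1, 32))

# Convert the apply function, with params closed over, to a TF function.
# jax2tf embeds the closed-over arrays as constants in the TF graph.
predict_tf = jax2tf.convert(lambda x: apply_fn(params, x))

# Wrap it in a tf.Module so it can be exported as a hermetic SavedModel.
module = tf.Module()
module.predict = tf.function(
    predict_tf,
    autograph=False,
    input_signature=[tf.TensorSpec(shape=[None, 32], dtype=tf.float32)],
)
tf.saved_model.save(module, '/tmp/nt_saved_model')

# Later, reload and run without any JAX / neural-tangents code.
restored = tf.saved_model.load('/tmp/nt_saved_model')
preds = restored.predict(tf.zeros([4, 32]))
```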

In general, I would probably save the model as NumPy arrays during training, and then look into the SavedModel pipeline if I wanted a longer-term storage option for using the model on downstream tasks.

Firegreat123 commented 2 years ago