ihh / bilby

Jax code for functional genomics ML

Save and load optimizer state in bilby #5

Open ihh opened 1 month ago

ihh commented 1 month ago

Currently, bilby auto-saves parameters during a training run and loads them back in on restart if they are available. However, it should also save the state of the optimizer. Adam maintains moving averages computed from past gradient steps (momentum and variance estimates), and these are currently lost when a training run is interrupted. It should be possible to serialize and deserialize the optax state object that holds this algorithm's internal state, and to autoload it when bilby starts, if it exists (see the sketch below). There should also be a command-line option to disable this, or to specify the path to the serialized Adam state.
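A minimal sketch of the round trip, assuming `flax.serialization` is available as a dependency; the parameter shapes and the `adam_state.msgpack` filename are placeholders, not bilby's actual layout:

```python
import flax.serialization
import jax.numpy as jnp
import optax

# Hypothetical parameters; in bilby these would come from the model init.
params = {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)}

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Serialize the optimizer state (momentum/variance estimates) to bytes.
with open("adam_state.msgpack", "wb") as f:
    f.write(flax.serialization.to_bytes(opt_state))

# On restart: re-initialize to get a state with the right pytree structure,
# then restore the saved values into it.
template = optimizer.init(params)
with open("adam_state.msgpack", "rb") as f:
    opt_state = flax.serialization.from_bytes(template, f.read())
```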

ihh commented 1 month ago

See e.g. https://github.com/google-deepmind/optax/discussions/180 for a discussion of several ways to do this. The "official" way is to use orbax-checkpoint, but I found that rather cumbersome, and (when I first tried it) I could not get all the orbax dependencies to work; maybe it will work now. Alternatively, there may be a simpler way of just serializing the TrainState object directly.
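For the simpler route, one possible sketch: since `flax.training.train_state.TrainState` is a pytree, `flax.serialization` can round-trip the whole thing, capturing params, step count, and optimizer state together. The placeholder `apply_fn` and params here are assumptions; bilby's real TrainState setup lives in train.py.

```python
import flax.serialization
from flax.training import train_state
import jax.numpy as jnp
import optax

# Hypothetical params and apply_fn, just to build a state with the right structure.
params = {"w": jnp.zeros((4, 4))}
state = train_state.TrainState.create(
    apply_fn=lambda p, x: x,  # placeholder; bilby's model apply function goes here
    params=params,
    tx=optax.adam(1e-3),
)

# Serialize the entire TrainState (params + step + opt_state) to bytes.
blob = flax.serialization.to_bytes(state)

# Restoring needs a freshly created TrainState as a structural template.
restored = flax.serialization.from_bytes(state, blob)
```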

The functions to load and save state probably need to live in state.py. The loader needs to be called somewhere around where the TrainState object is initialized in train.py, and the saver should probably be called from the save_vars function. A rough sketch of what those helpers might look like is below.
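A hypothetical sketch of the helpers in state.py; the names `save_opt_state`/`load_opt_state` and the default path are assumptions, not bilby's actual API:

```python
import os
import flax.serialization


def save_opt_state(opt_state, path="adam_state.msgpack"):
    """Serialize the optax optimizer state alongside the saved parameters."""
    with open(path, "wb") as f:
        f.write(flax.serialization.to_bytes(opt_state))


def load_opt_state(fresh_opt_state, path="adam_state.msgpack"):
    """Restore the optimizer state if a checkpoint exists; otherwise keep the fresh one."""
    if not os.path.exists(path):
        return fresh_opt_state
    with open(path, "rb") as f:
        return flax.serialization.from_bytes(fresh_opt_state, f.read())
```

Under this layout, train.py would pass its freshly initialized opt_state (or TrainState) through `load_opt_state` right after creating it, and save_vars would call `save_opt_state` whenever it writes parameters, with the command-line flag controlling whether the path is used at all.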