NeuroDiffGym / neurodiffeq

A library, built on PyTorch, for solving differential equations with neural networks. It is used by multiple research groups around the world, including at Harvard IACS.
http://pypi.org/project/neurodiffeq/
MIT License

Examples of other callbacks #116

Closed by troyrock 3 years ago

troyrock commented 3 years ago

Using L-BFGS leads to much faster optimization; however, it is sometimes unstable. Adding more training points seems to help, so I am doing that, but I would also like to checkpoint the solution so that, if training becomes unstable, I can go back to a known good state and try again with a smaller learning rate. Do you have an example of using callbacks to do something like that? It seems like all of the machinery is there, but I'm not sure how to use it. Thanks again for your help.

shuheng-liu commented 3 years ago

Yes, it's simple. Say you want to save a checkpoint every 20 epochs:


```python
from neurodiffeq.callbacks import CheckpointCallback, PeriodLocal
from neurodiffeq.solvers import Solver1D  # or another solver, e.g. Solver2D, GenericSolver

# The action: save a checkpoint in the current ('.') directory
act_cb = CheckpointCallback(ckpt_dir='.')
# The condition: perform the above action only every 20 local epochs
cond_cb = PeriodLocal(period=20)
# Link the action with the condition
cb = cond_cb.set_action_callback(act_cb)

solver = Solver1D(...)

solver.fit(max_epochs=100, callbacks=[..., cb])
```

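
The callback above takes care of saving checkpoints; going back to a known good state and retrying with a smaller learning rate is a separate step. A framework-independent sketch of that checkpoint-and-rollback logic follows, assuming plain gradient descent on a toy quadratic instead of L-BFGS inside neurodiffeq. The function `train_with_rollback`, its parameters `ckpt_every` and `shrink`, and the "loss grew more than 10x" divergence test are all hypothetical illustrations, not part of the neurodiffeq API.

```python
import copy
import math


def train_with_rollback(params, grad_fn, loss_fn, lr, epochs,
                        ckpt_every=20, shrink=0.5):
    """Hypothetical sketch: gradient descent with checkpoint/rollback.

    Every `ckpt_every` epochs a known-good copy of `params` is saved.
    If the loss becomes non-finite or grows past 10x the checkpointed
    loss, the checkpoint is restored and `lr` is multiplied by `shrink`.
    """
    ckpt_params, ckpt_loss = copy.deepcopy(params), loss_fn(params)
    rollbacks = 0
    for epoch in range(1, epochs + 1):
        grads = grad_fn(params)
        params = [p - lr * g for p, g in zip(params, grads)]
        loss = loss_fn(params)
        # Assumed divergence test: non-finite loss or a 10x blow-up
        if not math.isfinite(loss) or loss > 10 * ckpt_loss:
            params = copy.deepcopy(ckpt_params)  # back to known good state
            lr *= shrink                          # retry with a smaller step
            rollbacks += 1
            continue
        if epoch % ckpt_every == 0:
            ckpt_params, ckpt_loss = copy.deepcopy(params), loss
    return params, lr, rollbacks


# Toy problem: minimize x^2; lr=1.5 diverges, lr=0.75 converges
params, lr, rollbacks = train_with_rollback(
    [3.0],
    grad_fn=lambda ps: [2 * p for p in ps],
    loss_fn=lambda ps: sum(p * p for p in ps),
    lr=1.5, epochs=100)
print(lr, rollbacks)  # 0.75 1
```

With `lr=1.5` the iterate oscillates with growing amplitude, so after two steps the loss exceeds ten times the checkpointed value; the sketch restores the starting point once, halves the learning rate to 0.75, and then converges toward 0.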
shuheng-liu commented 3 years ago

On another note, we have been working on new features for saving the internal states of solvers on the save-load-solver branch. It's not merged into master yet, but once it is released you should be able to do more, such as continuing to train a pretrained network or performing transfer learning.

troyrock commented 3 years ago

Thank you, shuheng-liu! I look forward to the new features!