NeuroDiffGym / neurodiffeq

A library for solving differential equations using neural networks based on PyTorch, used by multiple research groups around the world, including at Harvard IACS.
http://pypi.org/project/neurodiffeq/
MIT License

Model Saving & Reloading #84

Closed mariusmerkle closed 3 years ago

mariusmerkle commented 3 years ago

Hi,

I am studying how transfer learning can enhance the training of physics-informed neural networks. The neurodiffeq library sparked my interest, and I was wondering whether it is possible to

  1. save a trained model, i.e. the parameters of the network and its architecture

  2. reload the saved model and continue training from that non-random state.

shuheng-liu commented 3 years ago

Hi Marius. Yes, it's easy to do as long as you know basic PyTorch usage. For saving and loading models, here's a useful link. In the context of neurodiffeq, you want to perform the following procedure:

  1. Make sure you have the model class MyNetwork saved in some file model.py. Note that MyNetwork must be a subclass of torch.nn.Module.

  2. Create one or more (depending on the number of functions you are solving for) MyNetwork instances:

    my_nets = [MyNetwork(...), MyNetwork(...), ...]
  3. Instantiate your solver and pass your model(s)

    solver = Solver1D(
        ...,
        nets=my_nets,
    )
  4. Do the training and get your networks. Currently, neurodiffeq doesn't make a copy of the networks passed to it, so solver.nets is the same list object as the my_nets created earlier.

    solver.fit(max_epochs=xxx, ...)
    my_nets = solver.nets # you can skip this step if you still have access to `my_nets` created earlier
  5. Save your model

    torch.save({f'net_{i}': net.state_dict() for i, net in enumerate(my_nets)}, YOUR_MODEL_PATH)
  6. In another script, instantiate your model using exactly the same architecture and load the weights

    loaded_nets = [MyNetwork(...), MyNetwork(...), ...]
    checkpoint = torch.load(YOUR_MODEL_PATH)
    for i, net in enumerate(loaded_nets):
        net.load_state_dict(checkpoint[f'net_{i}'])
  7. Redo steps 3 and 4, but change my_nets to loaded_nets.
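The steps above can be condensed into a plain-PyTorch sketch. The MyNetwork architecture and checkpoint path below are placeholders, and the solver-related steps (3, 4, and 7) are left as comments since they depend on your problem setup:

```python
import os
import tempfile

import torch
import torch.nn as nn


# Step 1: a placeholder architecture -- substitute your own; it only
# needs to be a subclass of torch.nn.Module.
class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)


# Placeholder checkpoint path.
MODEL_PATH = os.path.join(tempfile.gettempdir(), "my_nets_checkpoint.pt")

# Step 2: one network per unknown function being solved for.
my_nets = [MyNetwork(), MyNetwork()]

# (Steps 3-4: pass nets=my_nets to Solver1D and call solver.fit(...);
# solver.nets is the same list object, so my_nets stays up to date.)

# Step 5: save all state dicts in a single checkpoint file.
torch.save(
    {f"net_{i}": net.state_dict() for i, net in enumerate(my_nets)},
    MODEL_PATH,
)

# Step 6: in another script, rebuild the exact same architectures
# and load the saved weights into them.
loaded_nets = [MyNetwork(), MyNetwork()]
checkpoint = torch.load(MODEL_PATH)
for i, net in enumerate(loaded_nets):
    net.load_state_dict(checkpoint[f"net_{i}"])

# (Step 7: pass nets=loaded_nets to a fresh Solver1D and keep training.)
```

After loading, each network in loaded_nets has parameters identical to its counterpart in my_nets, so training resumes from the saved, non-random state.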

mariusmerkle commented 3 years ago

That sounds great! And it is possible to use both Adam optimiser and L-BFGS/L-BFGS-B, right?

shuheng-liu commented 3 years ago

Most optimizers are currently supported, except LBFGS, which is a little tricky (see #83). Luckily, a solution was proposed just recently, though we still need to run the tests.

I'm not familiar with L-BFGS-B, but it appears that this optimizer has not been implemented in PyTorch (see here). So currently, you can't use L-BFGS-B without implementing it yourself.
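For context on why LBFGS is trickier than Adam: torch.optim.LBFGS may re-evaluate the objective several times per step (e.g. during a line search), so its step() requires a closure that recomputes the loss, which a plain single-call training loop doesn't provide. A minimal illustration outside of neurodiffeq, on a toy least-squares problem:

```python
import torch

# Toy problem: fit a scalar w so that w * x ≈ y, where y = 2x.
x = torch.tensor([[1.0], [2.0], [3.0]])
y = 2.0 * x
w = torch.zeros(1, requires_grad=True)

# Adam: one loss evaluation per step, no closure needed.
adam = torch.optim.Adam([w], lr=0.1)
for _ in range(5):
    adam.zero_grad()
    loss = ((x * w - y) ** 2).mean()
    loss.backward()
    adam.step()

# LBFGS: step() takes a closure that re-evaluates the loss, because
# the optimizer may probe several candidate points per step.
lbfgs = torch.optim.LBFGS([w])

def closure():
    lbfgs.zero_grad()
    loss = ((x * w - y) ** 2).mean()
    loss.backward()
    return loss

for _ in range(10):
    lbfgs.step(closure)
```

On this quadratic problem, LBFGS drives w to the exact solution w = 2 within a few iterations; the point is simply that its training loop has a different shape than Adam's, which is what makes wiring it into a generic solver tricky.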