lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Save only the best model files #1787

Open · ricardo-ervilha opened this issue 3 months ago

ricardo-ervilha commented 3 months ago

Hey everyone,

I'm looking for a way to save only the best model during training, discarding the others as updates occur. Currently, I’m doing it like this:

checker = dde.callbacks.ModelCheckpoint(
    "model/model.ckpt", save_better_only=True, period=1000, verbose=1
)

losshistory, trainstate = model.train(
    iterations=ITERATIONS,
    batch_size=BATCH_SIZE,
    callbacks=[checker]
)

model.restore("model/model.ckpt-" + str(trainstate.best_step) + ".ckpt", verbose=1)

But this approach saves multiple files like:

model.ckpt-1000.ckpt.data-00000-of-00001, model.ckpt-2000.ckpt.data-00000-of-00001, ...

These extra files are unnecessary for me since I only need the weights of the best model at the end. Is there a way to always overwrite just one file that stores the weights for the best loss?
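One workaround that needs no changes to DeepXDE (a sketch, not a library feature) is to keep ModelCheckpoint as is and delete every non-best checkpoint once training has finished. The helper name prune_checkpoints is hypothetical, and the sketch assumes the TensorFlow 1.x-style file names listed in the question (model.ckpt-<step>.ckpt.<suffix>) and that trainstate.best_step matches one of the saved steps.

```python
import re
from pathlib import Path


def prune_checkpoints(directory, prefix, best_step):
    """Delete checkpoint files '<prefix>-<step>.ckpt*' whose step != best_step.

    Assumes TensorFlow 1.x-style names such as
    model.ckpt-1000.ckpt.data-00000-of-00001 or model.ckpt-1000.ckpt.index.
    Other files in the directory (e.g. the 'checkpoint' index) are left alone.
    Returns the sorted names of the deleted files.
    """
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.ckpt")
    removed = []
    # Materialize the listing first so we never delete while iterating a scan.
    for path in sorted(Path(directory).iterdir()):
        match = pattern.match(path.name)
        if match and int(match.group(1)) != best_step:
            path.unlink()  # remove a non-best checkpoint file
            removed.append(path.name)
    return sorted(removed)
```

After training it would be called as prune_checkpoints("model", "model.ckpt", trainstate.best_step), followed by the same model.restore(...) call. Disk usage still peaks during training, because ModelCheckpoint keeps writing one set of files per improvement.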

praksharma commented 2 months ago

I think you have to save the model's weights whenever the training loss at the current iteration is lower than the lowest loss seen so far. For that, you have to modify the source code.

If you are using Adam, you need to modify _train_sgd(), defined in the Model class, which contains the training for-loop. There you can create a new variable, e.g. bestPINN, that holds a copy of the weights made with copy.deepcopy(model.net.state_dict()), where model.net is the network you pass to dde.Model() (state_dict() is a PyTorch method, so this applies to the PyTorch backend).
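Why the deep copy matters: in PyTorch, state_dict() returns tensors that share storage with the live parameters, and the optimizer updates those in place, so a stored state_dict without a deep copy would silently follow the current weights instead of the best ones. The framework-free sketch below mimics this with a plain dict of lists standing in for a state_dict; fake_optimizer_step is a hypothetical stand-in for optimizer.step(), used so the example runs without PyTorch.

```python
import copy


def fake_optimizer_step(state, lr=0.5):
    """Hypothetical in-place update standing in for optimizer.step()."""
    weights = state["weight"]
    for i in range(len(weights)):
        weights[i] -= lr  # mutate the list in place, like param.add_() in PyTorch


state = {"weight": [1.0, 2.0]}   # stands in for model.net.state_dict()
aliased = state                  # plain assignment: the very same object
shallow = dict(state)            # new dict, but the inner list is shared
snapshot = copy.deepcopy(state)  # fully independent copy

fake_optimizer_step(state)

print(aliased["weight"])   # [0.5, 1.5]: followed the update
print(shallow["weight"])   # [0.5, 1.5]: the shared inner list changed too
print(snapshot["weight"])  # [1.0, 2.0]: still the saved weights
```

With a real network on the PyTorch backend, the saved snapshot would later be put back with model.net.load_state_dict(bestPINN).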

Now you can easily wrap this weight copy in an if statement that compares the current training loss with the lowest loss seen so far. You can compute the training loss by summing the entries of model.train_state.loss_train, which holds one value per loss term.
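A minimal sketch of that if-condition, kept independent of DeepXDE so it runs on its own: it sums the per-term losses (as one would sum model.train_state.loss_train) and snapshots the weights whenever the total beats the best loss so far. Comparing only with the immediately preceding iteration would also overwrite the snapshot when the loss merely recovers from a bad step. The function name track_best and the hard-coded loss history are hypothetical.

```python
import copy


def track_best(loss_history, weight_history):
    """Return (best_loss, best_step, best_weights) over a training run.

    loss_history[i] holds the per-term losses at step i (like
    model.train_state.loss_train); weight_history[i] stands in for the
    network weights at that step.
    """
    best_loss = float("inf")
    best_step = None
    best_weights = None
    for step, (loss_terms, weights) in enumerate(zip(loss_history, weight_history)):
        total = sum(loss_terms)  # total training loss = sum of the loss terms
        if total < best_loss:    # compare with the best so far, not the last step
            best_loss = total
            best_step = step
            best_weights = copy.deepcopy(weights)  # snapshot, not a reference
    return best_loss, best_step, best_weights


# Hypothetical run: totals are 5, 2, 4, 3. Step 3 improves on step 2,
# but step 1 is still the best, so only step 1 is kept.
losses = [[3.0, 2.0], [1.0, 1.0], [2.5, 1.5], [2.0, 1.0]]
weights = [{"w": [float(i)]} for i in range(len(losses))]

best_loss, best_step, best_weights = track_best(losses, weights)
print(best_loss, best_step, best_weights)  # 2.0 1 {'w': [1.0]}
```

Inside DeepXDE's _train_sgd loop, model.train_state.loss_train is only recomputed at evaluation steps (every display_every iterations), so in practice this check only changes its outcome at those steps.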

A few years ago, I did exactly the same thing with PINNs, but in plain PyTorch (not DeepXDE). Here you can find the relevant code, in the section named main.

ricardo-ervilha commented 2 months ago

Thank you very much for the response! Your PyTorch implementation helped me a lot!!