mtanneau opened this issue 1 week ago
I agree it is unfortunate that simply calling `BasicNeuralNet.train(...)` again doesn't just resume training. Indeed, calling `BasicNeuralNet.train(...)` with non-empty `trainer_kwargs` should give a warning when there is an existing trainer, since those kwargs won't be used (currently there is only a debug message):
https://github.com/AI4OPT/ML4OPF/blob/a778eb9ee9fb3d1fde4e5c66855cc07e16aa47f0/ml4opf/models/basic_nn/basic_nn.py#L88-L90
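Something like the following could replace the debug message (a rough sketch only; `existing_trainer` and `trainer_kwargs` stand in for whatever the actual names are at the linked lines):

```python
import warnings

def warn_ignored_trainer_kwargs(existing_trainer, trainer_kwargs: dict) -> None:
    """Warn visibly (instead of only debug-logging) when trainer_kwargs will be ignored."""
    if existing_trainer is not None and trainer_kwargs:
        warnings.warn(
            "An existing Trainer was found; the provided trainer_kwargs are ignored.",
            UserWarning,
        )
```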
Anyway, to work around this I typically update the trainer's internal `max_epochs`, then call `BasicNeuralNet.train()` again (with no kwargs). See the `train(model: BasicNeuralNet, epochs: int)` function in `tests/test_models.py` (specifically line 73):
https://github.com/AI4OPT/ML4OPF/blob/a778eb9ee9fb3d1fde4e5c66855cc07e16aa47f0/tests/test_models.py#L65-L86
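In code, the pattern looks roughly like this (a sketch rather than the verbatim test code; it assumes the model exposes its Lightning Trainer as `model.trainer` and that `max_epochs` lives on the trainer's fit loop, as in recent PyTorch Lightning releases):

```python
def train_more(model, extra_epochs: int) -> None:
    """Resume training an already-trained BasicNeuralNet for extra_epochs more epochs."""
    trainer = model.trainer  # assumed attribute holding the existing Lightning Trainer
    # Raise the epoch budget so fit() does not exit immediately at the old max_epochs.
    trainer.fit_loop.max_epochs = trainer.current_epoch + extra_epochs
    # Re-invoke train() with no trainer_kwargs (they would be ignored anyway);
    # training resumes from the current weights.
    model.train()
```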
The current training mechanism for `BasicNeuralNet` is just a very thin wrapper over the PyTorch Lightning `Trainer`. I wouldn't be opposed to adding nice-to-haves like this to it in the future.
I used the starter code in the README to train an `ACBasicNeuralNet` for a few epochs. I then evaluated its performance, and after seeing the results, I would like to train it further. For instance, I would like to run it for another 16 epochs.

I tried the following, which immediately stopped because the maximum number of epochs had already been reached.
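(Roughly, calls along these lines; the snippet below is an illustrative approximation rather than my exact code, and assumes Lightning options such as `max_epochs` are passed through `trainer_kwargs`.)

```python
# Illustrative only: re-invoking train() after the initial run, asking for more epochs.
model.train(trainer_kwargs={"max_epochs": 16})
```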
Calling `model.train` again with a higher number of epochs also terminates immediately with the same output.

Is there a convenient way to continue/restart training, without losing the current weights?