ml5js / ml5-library

Friendly machine learning for the web! 🤖
https://ml5js.org

Is it possible to stop training a neural network midway? #1397

Open vrunhofen opened 2 years ago

vrunhofen commented 2 years ago

I'm training various custom neural networks and getting decent results. But I often want to stop training conditionally, before the model overfits, based on the loss, val_loss, acc, val_acc, etc. values passed to the whileTraining callback of nn.train.

I would like to do something like this:

nn.train({
  epochs: 200,
  batchSize: 2500
}, 
(epoch, loss) => {
  if(loss.val_loss < 0.007 && loss.loss < 0.002){
    nn.stopTraining()
    // or nn.finishTraining()
  }
}, 
() => console.log('done'))

Is this possible? I couldn't find any documentation suggesting that it is, so I thought I'd ask. TensorFlow supports this through an early-stopping callback (tf.callbacks.earlyStopping in TensorFlow.js). I even checked the source here on GitHub for an undocumented function, but couldn't find anything.

lindapaiste commented 2 years ago

Good question! This is not something that we currently support, but it shouldn't be hard to implement, since it's just a matter of passing an extra argument down to the TensorFlow model's .fit() function.

The way that we are handling the whileTraining and finishedTraining callbacks right now is a bit fragile and something that I will probably "rip to shreds" eventually. 😝

I'm thinking that the simplest way to add this to the current setup would be to allow a callbacks property on the options object. No change would be required here, since we are spreading all options. The TF callbacks option can be an array and this whileTraining variable is already an array, so basically we just need to optionally concatenate _options.callbacks here: https://github.com/ml5js/ml5-library/blob/2c5cd1e9901d1d30ea592de01ca6ea28369888cb/src/NeuralNetwork/NeuralNetwork.js#L113-L119

Actually, the more I think about it, the more it makes sense to handle it somewhere in here: https://github.com/ml5js/ml5-library/blob/2c5cd1e9901d1d30ea592de01ca6ea28369888cb/src/NeuralNetwork/index.js#L512-L544
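A minimal sketch of the "optionally concatenate" step, assuming a hypothetical helper named mergeCallbacks (it is not part of ml5). It takes the existing whileTraining array plus whatever the user passed as options.callbacks and returns one flat array; since the callbacks option of TF.js fit() can be either a single callback or an array, the helper normalizes both cases:

```javascript
// Hypothetical helper (not part of ml5): merge ml5's internal whileTraining
// callbacks with user-supplied TF.js callbacks from options.callbacks.
function mergeCallbacks(whileTraining, userCallbacks) {
  const internal = Array.isArray(whileTraining) ? whileTraining : [];
  if (userCallbacks == null) {
    // Nothing extra supplied: keep the current behavior unchanged.
    return [...internal];
  }
  // Accept either a single callback object or an array of them.
  const extra = Array.isArray(userCallbacks) ? userCallbacks : [userCallbacks];
  return internal.concat(extra);
}

// Usage with stand-in callback objects:
const whileTraining = [{ onEpochEnd: () => {} }];
const earlyStop = { onEpochEnd: () => {} }; // stands in for tf.callbacks.earlyStopping(...)
const merged = mergeCallbacks(whileTraining, earlyStop);
console.log(merged.length); // 2
```

Returning a new array rather than pushing into whileTraining keeps the internal list untouched between calls to train().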


That is not very user-friendly, but it would get the job done by allowing you to pass in the callback returned by tf.callbacks.earlyStopping(). Though I'm not sure it meets your needs, because it looks at the change from epoch to epoch (to see whether a value has stopped improving) rather than at the value itself. You would call:

nn.train({
  epochs: 200,
  batchSize: 2500,
  callbacks: tf.callbacks.earlyStopping({ monitor: 'val_loss' })
});

Or I think you could watch both?

nn.train({
  epochs: 200,
  batchSize: 2500,
  callbacks: [
    tf.callbacks.earlyStopping({ monitor: 'val_loss' }),
    tf.callbacks.earlyStopping({ monitor: 'loss' })
  ]
});
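To make the delta-versus-value distinction concrete, here is a simplified re-implementation of patience-based early stopping for a metric that should decrease, such as val_loss. It is a sketch of the general idea, not TF.js's actual code; makeEarlyStopper and stopEpoch are hypothetical names, while minDelta and patience mirror option names that tf.callbacks.earlyStopping accepts. Note that it never compares the loss against a fixed target like 0.007; it only asks whether the value has stopped improving:

```javascript
// Simplified sketch of delta-based early stopping for a metric to minimize.
// Not TF.js's implementation; it only mirrors the general idea.
function makeEarlyStopper({ minDelta = 0, patience = 0 } = {}) {
  let best = Infinity; // best (lowest) value seen so far
  let wait = 0;        // epochs since the last real improvement
  // Returns true when training should stop after this epoch.
  return function onEpochEnd(current) {
    if (current < best - minDelta) {
      best = current; // improved by more than minDelta: reset the counter
      wait = 0;
      return false;
    }
    wait += 1;
    return wait >= patience;
  };
}

// Feed an illustrative loss curve and report the epoch at which it would stop.
function stopEpoch(values, options) {
  const shouldStop = makeEarlyStopper(options);
  for (let epoch = 0; epoch < values.length; epoch++) {
    if (shouldStop(values[epoch])) return epoch;
  }
  return -1; // never triggered
}

console.log(stopEpoch([0.5, 0.4, 0.41, 0.42, 0.3], { patience: 2 })); // 3
```

With patience: 2 this made-up curve stops at epoch 3, even though the loss would have dropped to 0.3 at epoch 4. That is the trade-off of watching deltas rather than absolute values, and why a threshold check like the one in the question needs a different mechanism.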

The sort of callback you are proposing might be possible as a whileTraining callback function if we were to expose a stopTraining() method like you suggest. I have to look into the TF code a bit more because I don't see it in the documentation, but model.stopTraining definitely exists. Actually, you can access it right now:

nn.neuralNetwork.model.stopTraining = true

I wouldn't recommend having production code rely on multiple levels of internal properties like that. But maybe play around with it, and if it works we can add a shortcut function nn.stopTraining():

stopTraining() {
  // proposed shortcut: set the stop flag on the wrapped TF.js model
  this.neuralNetwork.model.stopTraining = true;
  // not sure which callbacks would get called when training is aborted
}
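A self-contained sketch of how such a shortcut could interact with the threshold check from the original question. Because TF.js isn't loaded here, it uses a hypothetical FakeModel with a stopTraining flag and a hypothetical fakeFit loop that calls whileTraining after each epoch and then checks the flag, roughly the way (as an assumption) TF.js's fit() honors model.stopTraining between epochs. The Wrapper class stands in for the ml5 object, and the per-epoch losses are invented for illustration:

```javascript
// Hypothetical stand-in for a TF.js model: only the stopTraining flag matters here.
class FakeModel {
  constructor() {
    this.stopTraining = false;
  }
}

// Hypothetical training loop: runs whileTraining after each epoch, then checks
// model.stopTraining and exits early if a callback requested a stop.
function fakeFit(model, epochs, lossCurve, whileTraining) {
  model.stopTraining = false;
  let epochsRun = 0;
  for (let epoch = 0; epoch < epochs; epoch++) {
    whileTraining(epoch, lossCurve[epoch]);
    epochsRun += 1;
    if (model.stopTraining) break; // stop requested from inside the callback
  }
  return epochsRun;
}

// Hypothetical wrapper mirroring the nn.neuralNetwork.model nesting.
class Wrapper {
  constructor() {
    this.neuralNetwork = { model: new FakeModel() };
  }
  stopTraining() {
    this.neuralNetwork.model.stopTraining = true;
  }
}

const nn = new Wrapper();
// Invented per-epoch logs, shaped like the loss object in the question.
const lossCurve = [
  { loss: 0.01, val_loss: 0.02 },
  { loss: 0.004, val_loss: 0.009 },
  { loss: 0.0015, val_loss: 0.006 },
  { loss: 0.001, val_loss: 0.005 },
];

const epochsRun = fakeFit(nn.neuralNetwork.model, lossCurve.length, lossCurve, (epoch, loss) => {
  // The value-based condition from the original question.
  if (loss.val_loss < 0.007 && loss.loss < 0.002) {
    nn.stopTraining();
  }
});
console.log(`stopped after ${epochsRun} epochs`); // stopped after 3 epochs
```

The real TF.js callbacks are async and the flag may also be checked between batches, so the exact stopping point in ml5 could differ; the sketch only shows the flag-based pattern.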