vrunhofen opened this issue 2 years ago
Good question! This is not something that we currently support, but it shouldn't be hard to implement, since it's just a matter of passing an extra argument down to the TensorFlow model's `.fit()` function.
The way that we are handling the `whileTraining` and `finishedTraining` callbacks right now is a bit fragile and something that I will probably "rip to shreds" eventually. 😝
I'm thinking that the simplest way to add this to the current setup would be to allow a `callbacks` property on the `options` object. No change would be required here, since we are spreading all options. The TF `callbacks` option can be an array, and this `whileTraining` variable is already an array, so basically we just need to optionally concatenate `_options.callbacks` here:
https://github.com/ml5js/ml5-library/blob/2c5cd1e9901d1d30ea592de01ca6ea28369888cb/src/NeuralNetwork/NeuralNetwork.js#L113-L119
Actually, the more I think about it, the more it makes sense to handle it somewhere in here: https://github.com/ml5js/ml5-library/blob/2c5cd1e9901d1d30ea592de01ca6ea28369888cb/src/NeuralNetwork/index.js#L512-L544
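A minimal sketch of that optional concatenation, assuming the internal callbacks are collected in an array (as `whileTraining` is) and that users pass either one TF.js callback or an array through the proposed `callbacks` option. `mergeCallbacks` is a hypothetical name, not ml5 code, and plain objects stand in for real callbacks so the sketch runs on its own.

```javascript
// Hypothetical helper (not part of ml5): appends user-supplied TF.js callbacks
// from options.callbacks to ml5's internal training callbacks.
function mergeCallbacks(internalCallbacks, options = {}) {
  // Copy into a new array so the caller's array is never mutated.
  const merged = internalCallbacks == null ? [] : [].concat(internalCallbacks);
  const extra = options.callbacks;
  if (extra == null) return merged; // no user callbacks: current behavior
  // concat appends a single callback as-is and flattens an array one level,
  // so both forms that TF.js accepts are handled.
  return merged.concat(extra);
}

// Demo with plain objects standing in for real callbacks.
const whileTraining = [{ onEpochEnd: (epoch, logs) => {} }];
const earlyStop = { onEpochEnd: (epoch, logs) => {} }; // e.g. tf.callbacks.earlyStopping()
console.log(mergeCallbacks(whileTraining, { callbacks: earlyStop }).length); // 2
console.log(mergeCallbacks(whileTraining, {}).length); // 1
```

The merged array would then go to `model.fit()` as its `callbacks` option. One thing worth verifying against the TF.js version in use: an array that mixes plain callback objects with `tf.callbacks.earlyStopping()` instances may not be handled the same way as an array of only one kind.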
That is not very user-friendly, but it would get the job done by allowing you to pass in a `tf.callbacks.earlyStopping()` callback. Though I'm not sure that it meets your needs, because it is based on looking at the delta from epoch to epoch (to see if a value stops improving) rather than at the value itself. You would call:
```javascript
nn.train({
  epochs: 200,
  batchSize: 2500,
  callbacks: tf.callbacks.earlyStopping({ monitor: 'val_loss' })
});
```
Or, I think, you could watch both:
```javascript
nn.train({
  epochs: 200,
  batchSize: 2500,
  callbacks: [
    tf.callbacks.earlyStopping({ monitor: 'val_loss' }),
    tf.callbacks.earlyStopping({ monitor: 'loss' })
  ]
});
```
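To make the delta-versus-value distinction concrete, here is a simplified, framework-free simulation of the `minDelta`/`patience` rule that Keras-style early stopping (which `tf.callbacks.earlyStopping()` follows) applies when minimizing a metric. It is an illustration, not the TF.js source; the function name `earlyStopEpoch` and the sample losses are made up.

```javascript
// Simplified simulation of Keras-style early stopping when minimizing a metric.
// Returns the epoch index at which training would stop, or -1 if it never does.
function earlyStopEpoch(values, { minDelta = 0, patience = 0 } = {}) {
  let best = Infinity; // best value seen so far
  let wait = 0; // epochs since the last significant improvement
  for (let epoch = 0; epoch < values.length; epoch++) {
    const current = values[epoch];
    if (current + minDelta < best) {
      // Improved by more than minDelta: remember it and reset the counter.
      best = current;
      wait = 0;
    } else {
      wait += 1;
      if (wait >= patience) return epoch;
    }
  }
  return -1;
}

// 0.79 is a new low, but it beats 0.8 by less than minDelta, so it counts as
// "no improvement"; the absolute value of val_loss never matters here.
const valLoss = [1.0, 0.8, 0.79, 0.8, 0.81];
console.log(earlyStopEpoch(valLoss, { minDelta: 0.05, patience: 2 })); // 3
```

Because each callback only compares a metric with its own best value so far, watching both `loss` and `val_loss` simply means training stops as soon as either one stalls.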
The sort of callback that you are proposing might be possible as a `whileTraining` callback function if we were to expose a `stopTraining()` method like you suggest. I have to look into the TF code a bit more because I am not seeing it in the documentation, but `model.stopTraining` is definitely a thing that exists. In fact, you can access it right now:
```javascript
nn.neuralNetwork.model.stopTraining = true;
```
I wouldn't recommend having production code that relies on multiple levels of internal properties like that. But maybe play around with it, and if it works we can add a shortcut function, `nn.stopTraining()`:
```javascript
stopTraining() {
  this.neuralNetwork.model.stopTraining = true;
  // not sure which callbacks would get called when aborting
}
```
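The threshold check you describe (stop on the value itself rather than its delta) can be sketched as a `whileTraining` callback. This assumes `whileTraining` is called as `(epoch, logs)` with the metrics you listed, and it takes the stop action as a function so it works with either the internal property today or a future `nn.stopTraining()`. The name `stopAtThreshold` and its options are hypothetical, and a mock object stands in for `nn.neuralNetwork.model` so the sketch runs on its own.

```javascript
// Hypothetical factory (not part of ml5): returns a whileTraining callback that
// calls `stop` once a monitored metric reaches a fixed threshold.
// ASSUMPTION: whileTraining is invoked as (epoch, logs), where logs holds
// metrics such as { loss, val_loss, acc, val_acc }.
function stopAtThreshold(stop, { monitor = 'val_loss', threshold, mode = 'min' } = {}) {
  if (typeof threshold !== 'number') {
    throw new Error('stopAtThreshold: a numeric threshold is required');
  }
  let stopped = false; // make sure `stop` runs at most once
  return (epoch, logs) => {
    const value = logs ? logs[monitor] : undefined;
    if (stopped || typeof value !== 'number') return; // already stopped or metric missing
    // 'min': stop once the value falls to the threshold (e.g. val_loss);
    // 'max': stop once it rises to the threshold (e.g. val_acc).
    const reached = mode === 'min' ? value <= threshold : value >= threshold;
    if (reached) {
      stopped = true;
      stop(epoch, value);
    }
  };
}

// Demo: a mock object stands in for nn.neuralNetwork.model.
const mockModel = { stopTraining: false };
const whileTraining = stopAtThreshold(
  () => { mockModel.stopTraining = true; },
  { monitor: 'val_acc', threshold: 0.95, mode: 'max' }
);
whileTraining(0, { loss: 0.4, val_acc: 0.9 }); // below threshold: keep training
whileTraining(1, { loss: 0.2, val_acc: 0.96 }); // threshold reached: request stop
console.log(mockModel.stopTraining); // true
```

In an ml5 sketch, the stop function would be something like `() => { nn.neuralNetwork.model.stopTraining = true; }`, with the returned function passed as the `whileTraining` argument of `nn.train`.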
I'm training various custom neural networks and I'm getting decent results. But I often feel the need to conditionally stop training before overfitting, based on `loss`, `val_loss`, `acc`, `val_acc`, etc., which are provided in the callback to `nn.train`.
I would like to do something like this:
Is this possible? I couldn't find any documentation suggesting that it is, so I thought I'd ask. TensorFlow does allow this with its early stopping callback. I even checked the source here on GitHub to see if there was an undocumented function, but couldn't find anything.