tensorflow / skflow

Simplified interface for TensorFlow (mimicking Scikit Learn) for Deep Learning
Apache License 2.0

Add support for validation sets #85

Closed: untom closed this issue 8 years ago

untom commented 8 years ago

It would be nice if skflow had some support for validation sets, to be used for early stopping and for monitoring validation-set loss during training. This could be realized fairly easily by adding a fraction_validationset argument to the TensorFlowEstimator. Within fit, the given training set could then be split into two parts.
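A minimal sketch of what that split could look like, assuming the proposed fraction_validationset argument; the helper below is illustrative, not part of skflow:

```python
import numpy as np

def split_validation(X, y, fraction_validationset=0.1, seed=42):
    """Shuffle (X, y) and hold out a fraction for validation."""
    rng = np.random.RandomState(seed)
    indices = rng.permutation(len(X))
    n_valid = int(len(X) * fraction_validationset)
    valid_idx, train_idx = indices[:n_valid], indices[n_valid:]
    return X[train_idx], y[train_idx], X[valid_idx], y[valid_idx]
```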

ilblackdragon commented 8 years ago

Validation set support is a good idea. One concern I have is how this fits with the sklearn interface when a user wants to pass their own validation set, because usually you split your dataset 3 ways - train, validation, and test - and use the validation set for hyperparameter search.
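For context, that three-way split is typically done with two chained train_test_split calls (module path as in current scikit-learn; the proportions here are illustrative):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X, y = np.random.rand(1000, 10), np.random.randint(0, 2, size=1000)

# 60% train, 20% validation (hyperparameter search), 20% held-out test.
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.4)
X_valid, X_test, y_valid, y_test = train_test_split(X_rest, y_rest, test_size=0.5)
```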

makseq commented 8 years ago

I implemented this by passing a delegate validation function through fit(..., cross_valid_fn=my_func), where my_func() returns something like calc_rmse(test_data). It's very useful because the user can adjust my_func() however they want.
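A sketch of that delegate idea; the cross_valid_fn hook is hypothetical (not an actual skflow fit() parameter), and the RMSE metric is just one possible choice:

```python
import numpy as np

def make_validation_fn(estimator, X_valid, y_valid):
    """Build a cross_valid_fn-style callback that reports validation RMSE."""
    def my_func():
        preds = estimator.predict(X_valid)
        return np.sqrt(np.mean((preds - y_valid) ** 2))
    return my_func

# Hypothetical usage:
# estimator.fit(X, y, cross_valid_fn=make_validation_fn(estimator, X_valid, y_valid))
```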

dansbecker commented 8 years ago

I was going to take a look at this (though I'm still learning my way around skflow).

I see an early_stopping_rounds argument in the estimators (TensorFlowEstimator and its derived classes), and the argument is passed to the TensorFlowTrainer, which appears to implement early stopping logic.

Is the current early-stopping logic different from what's suggested in this issue? I'll pursue this issue further once I understand how the two differ.

ilblackdragon commented 8 years ago

@dansbecker Thanks for taking a look! Right now early stopping is based on the training loss - e.g. if training has converged, the model stops before the requested number of steps.
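To illustrate the convergence rule described above (the window and tolerance here are made up for the sketch):

```python
def has_converged(loss_history, window=10, tol=1e-4):
    """True once the average per-step improvement in training loss is tiny."""
    if len(loss_history) < window + 1:
        return False
    recent = loss_history[-(window + 1):]
    return (recent[0] - recent[-1]) / window < tol
```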

On the other hand, using a validation set is another option. But so far, the examples implement it with an explicit loop: https://github.com/tensorflow/skflow/blob/master/examples/resnet.py#L148 So one way this issue could be addressed is by adding something like skflow.train(estimator, X_train, y_train, X_valid, y_valid, metric?) that runs this loop and also stops if the validation metric stops improving.
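A sketch of what that skflow.train helper could look like, modeled on the loop in the resnet.py example plus a simple patience rule; it assumes the estimator was constructed with continue_training=True so repeated fit() calls resume from the current weights:

```python
from sklearn.metrics import accuracy_score

def train(estimator, X_train, y_train, X_valid, y_valid, rounds=100, patience=5):
    """Alternate short training runs with validation checks; stop on plateau."""
    best_score, stale_rounds = -float("inf"), 0
    for _ in range(rounds):
        estimator.fit(X_train, y_train)  # resumes training from current weights
        score = accuracy_score(y_valid, estimator.predict(X_valid))
        if score > best_score:
            best_score, stale_rounds = score, 0
        else:
            stale_rounds += 1
        if stale_rounds >= patience:
            break  # validation metric stopped improving
    return estimator
```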

dansbecker commented 8 years ago

Thanks @ilblackdragon. That makes sense.

On the issue of using the same data for early stopping and for hyperparameter search: Three options come to mind.

  1. Let the user pass two sets of data to the fit method: one set for training the network, and a second (which @untom calls the validation set in this issue) for determining when to stop training. As I think you mentioned, this raises the question of whether the same validation data can also be used for hyperparameter search.

    It feels reasonable if I think of the number of training steps as a hyperparameter, in which case we'd expect the same data to determine the number of steps as is used to tune the other hyperparameters.

    However, I think this option is the least consistent with the sklearn interface.

    Incidentally, this is the approach that keras uses.

  2. Let the user specify an argument for what fraction of the training data is used only for early stopping (and not used to update weights). In this approach, they don't pass a separate dataset for early stopping.

    This option may confuse some users, who expect all data in the training set to be used to determine network weights. However, it's more consistent with the sklearn interface than the first option, and I think most users will find it the easiest to use.

  3. Let the user create a monitor object that tells the network when to stop. The user specifies the relevant data when creating that object, and the monitor is an optional argument to the fit method (see the sketch after this list). This is how I interpret what @makseq described above, and this post describes its use with sklearn's GradientBoostingClassifier.
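A minimal sketch of the monitor-object idea from option 3; the class name, its update() interface, and the accuracy metric are all hypothetical here:

```python
from sklearn.metrics import accuracy_score

class ValidationMonitor:
    """Tracks a validation metric and signals when training should stop."""

    def __init__(self, X_valid, y_valid, early_stopping_rounds=10):
        self.X_valid, self.y_valid = X_valid, y_valid
        self.early_stopping_rounds = early_stopping_rounds
        self.best_score = -float("inf")
        self.stale_rounds = 0

    def update(self, estimator):
        """Called by the trainer every few steps; returns True to stop."""
        score = accuracy_score(self.y_valid, estimator.predict(self.X_valid))
        if score > self.best_score:
            self.best_score, self.stale_rounds = score, 0
        else:
            self.stale_rounds += 1
        return self.stale_rounds >= self.early_stopping_rounds
```

Usage would then look something like estimator.fit(X_train, y_train, monitor=ValidationMonitor(X_valid, y_valid)), with the trainer calling monitor.update(self) periodically.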

Thoughts?

dansbecker commented 8 years ago

@ilblackdragon: The three options above are in addition to the one you mentioned, skflow.train(estimator, X_train, y_train, X_valid, y_valid, metric?).

My inclination is towards either your suggestion, or option 2 or 3 from my note above.