tensorflow / skflow

Simplified interface for TensorFlow (mimicking Scikit Learn) for Deep Learning
Apache License 2.0

Add early stopping and reporting based on validation data #104

Closed: dansbecker closed this 8 years ago

dansbecker commented 8 years ago

This PR allows a user to specify a validation dataset that is used for early stopping (and reporting). The PR was created to address issue #85.

I made changes in 3 places.

  1. The trainer now takes a dictionary containing the validation data (in the same format as the output of the data feeder's get_dict_fn).
  2. The fit method now takes arguments for val_X and val_y. It converts these into the correct format for the trainer.
  3. The example file digits.py now uses early stopping by supplying val_X and val_y.
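The early-stopping loop these changes enable can be sketched without TensorFlow as below. This is a minimal illustration under stated assumptions, not the PR's code: the helper name `train_with_early_stopping`, the `step_fn`/`val_loss_fn` callables (standing in for one training step and for evaluating the loss on the validation dictionary), and the `patience` parameter are all hypothetical.

```python
def train_with_early_stopping(step_fn, val_loss_fn, max_steps, patience=5):
    """Hypothetical helper (not skflow's API).

    Runs up to max_steps training steps and stops once the validation loss
    has not improved for `patience` consecutive evaluations.
    Returns (steps_run, best_validation_loss).
    """
    best_loss = float("inf")
    steps_without_improvement = 0
    steps_run = 0
    for step in range(max_steps):
        step_fn(step)            # one optimizer step on the training data
        steps_run = step + 1
        loss = val_loss_fn()     # loss on the held-out validation data
        if loss < best_loss:
            best_loss = loss
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                break            # validation loss stopped improving
    return steps_run, best_loss


# Fake validation losses: improvement stalls after the third step.
losses = iter([5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 2.0])
steps, best = train_with_early_stopping(lambda step: None, lambda: next(losses),
                                        max_steps=100, patience=3)
print(steps, best)  # → 6 3.0
```

The final loss of 2.0 is never reached: three non-improving evaluations in a row end training first, which is exactly the trade-off `patience` controls.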

I can add early stopping to other examples if this approach looks good, though their behavior should not otherwise be affected by the current PR.

dansbecker commented 8 years ago

The original PR fails the linter because TensorFlowTrainer.train has too many branches (14). I think this accurately reflects the amount of complexity that's creeping into the train function.

Some of the reporting logic currently lives in train and some in _print_report.

I'd like to refactor reporting so that all of the logic and printing live in a TensorFlowTrainer.report method. I think this is a worthwhile change independent of the current linter error. But since this change will also resolve that error, do you mind if I include it in the current PR?
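A minimal sketch of what consolidating reporting into one method could look like. The `Trainer` class, its `print_every`/`verbose`/`out` parameters, and the message format are hypothetical and not the actual skflow code; the point is only that `train()` keeps a single call and all the "when and what to print" branches move into `report()`.

```python
class Trainer:
    """Hypothetical trainer sketch (not skflow's TensorFlowTrainer).

    train() only drives the loop; every decision about when and what to
    print lives in report(), which keeps train() short and lint-friendly.
    """

    def __init__(self, print_every=100, verbose=1, out=print):
        self.print_every = print_every
        self.verbose = verbose
        self.out = out  # where report lines go; print by default

    def report(self, step, loss, val_loss=None):
        """Emit a progress line when one is due; return it, or None if skipped."""
        if self.verbose < 1 or step % self.print_every != 0:
            return None
        line = "Step #%d, loss: %.5f" % (step, loss)
        if val_loss is not None:
            line += ", validation loss: %.5f" % val_loss
        self.out(line)
        return line

    def train(self, step_fn, steps, val_loss_fn=None):
        """Run `steps` steps; step_fn(step) returns the training loss."""
        for step in range(steps):
            loss = step_fn(step)
            # Evaluated every step for simplicity; a real trainer might only
            # evaluate when a report is due.
            val_loss = val_loss_fn() if val_loss_fn is not None else None
            self.report(step, loss, val_loss)


lines = []
Trainer(print_every=2, out=lines.append).train(lambda step: 0.25, steps=5)
print(lines)  # reports for steps 0, 2 and 4
```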

terrytangyuan commented 8 years ago

@dansbecker This PR looks good to me. Regarding the lint error, you can disable it locally by adding a comment like # pylint: disable=too-many-branches to resolve it for now, but we may need to refactor at some point. @ilblackdragon What do you think? Could you check this PR?

dansbecker commented 8 years ago

@terrytangyuan I updated the PR with the line you sent. If you are open to the refactor (moving the output and reporting into a separate method), I'll submit it as a separate PR.

codecov-io commented 8 years ago

Current coverage is 90.49%

Merging #104 into master will increase coverage by +0.91% as of a329538

@@            master    #104   diff @@
======================================
  Files           26      27     +1
  Stmts          989    1042    +53
  Branches       160     159     -1
  Methods          0       0       
======================================
+ Hit            886     943    +57
+ Partial         49      47     -2
+ Missed          54      52     -2


ilblackdragon commented 8 years ago

Refactoring sounds good for the trainer.

I actually think reporting may even be worth moving out of the trainer entirely, into something like the "Monitor" interface you mentioned in another thread (i.e., something you attach to the training process that can be customized). Then we can have a simple ConsoleMonitor that reproduces the current behaviour, and we could also extend this to store info in a local database or any other reporting tooling we want to add over time.
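A minimal sketch of that idea, assuming a monitor is an object with `begin`/`update`/`end` hooks that the training loop calls. The class names `Monitor` and `ConsoleMonitor` come from the comment above, but the hook names, their signatures, and the toy `train` loop are hypothetical, not an existing skflow API.

```python
class Monitor:
    """Hypothetical base class: hooks the training loop calls."""

    def begin(self, max_steps):
        """Called once before training starts."""

    def update(self, step, loss):
        """Called after every step; return True to request early stopping."""
        return False

    def end(self):
        """Called once after training finishes."""


class ConsoleMonitor(Monitor):
    """Prints progress every `print_every` steps (the 'current behaviour')."""

    def __init__(self, print_every=100, out=print):
        self.print_every = print_every
        self.out = out

    def update(self, step, loss):
        if step % self.print_every == 0:
            self.out("Step #%d, loss: %.5f" % (step, loss))
        return False


def train(step_fn, max_steps, monitors=()):
    """Toy training loop that drives a list of monitors."""
    for m in monitors:
        m.begin(max_steps)
    for step in range(max_steps):
        loss = step_fn(step)
        # Build the full list first so every monitor sees every step;
        # any() over a generator would short-circuit and skip later monitors.
        stop = [m.update(step, loss) for m in monitors]
        if any(stop):
            break
    for m in monitors:
        m.end()


lines = []
train(lambda step: 0.5, max_steps=5,
      monitors=[ConsoleMonitor(print_every=2, out=lines.append)])
print(lines)  # → ['Step #0, loss: 0.50000', 'Step #2, loss: 0.50000', 'Step #4, loss: 0.50000']
```

Because stopping is just a return value from `update`, the same hook can serve both plain reporting and early stopping.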

On the current PR: I'm a bit hesitant about adding more parameters to the fit() method, because they are not supported by the rest of the scikit-learn tooling (like pipelines, cross-validation, etc.).

Also, scikit-learn already uses monitors in some places, including for early stopping.
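For reference, scikit-learn's gradient boosting estimators accept a `monitor` callable in `fit()`; it is invoked after each stage as `monitor(i, estimator, locals())`, and fitting stops when it returns True. The toy loop below is a framework-free sketch of that calling convention only: `ToyEstimator`, `stop_below`, and the fake per-stage losses are invented for illustration.

```python
class ToyEstimator:
    """Hypothetical estimator whose fit() honours a scikit-learn-style monitor."""

    def __init__(self, n_stages=100):
        self.n_stages = n_stages
        self.n_stages_fitted_ = 0

    def fit(self, stage_losses, monitor=None):
        # stage_losses stands in for the loss a real model would compute per stage
        for i in range(self.n_stages):
            loss = stage_losses[i]
            self.n_stages_fitted_ = i + 1
            # Same convention as scikit-learn: monitor(i, self, locals())
            if monitor is not None and monitor(i, self, locals()):
                break
        return self


def stop_below(threshold):
    """Build a monitor that stops once the current stage's loss drops below threshold."""
    def monitor(i, estimator, local_vars):
        return local_vars["loss"] < threshold
    return monitor


est = ToyEstimator(n_stages=5).fit([0.9, 0.5, 0.2, 0.1, 0.05], monitor=stop_below(0.3))
print(est.n_stages_fitted_)  # → 3
```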

Maybe we can have a ValidationMonitor(X_val, y_val) that would return True once the model has converged.
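One way the proposed ValidationMonitor(X_val, y_val) could decide "converged" is sketched below. This is an assumption-laden sketch, not the implementation that was merged: it assumes the model exposes `predict(X)` returning a list of numbers, measures mean squared error on the validation set, and treats the model as converged once that error fails to improve by at least `min_delta` for `patience` consecutive checks (`patience`, `min_delta`, and the toy `LinearModel` are all hypothetical).

```python
class ValidationMonitor:
    """Hypothetical sketch of the proposed ValidationMonitor(X_val, y_val).

    Call it with the model after each training step; it returns True
    ("converged") once validation MSE has failed to improve by at least
    `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, X_val, y_val, patience=1, min_delta=0.0):
        self.X_val = X_val
        self.y_val = y_val
        self.patience = patience
        self.min_delta = min_delta
        self.best_error = float("inf")
        self.bad_checks = 0

    def validation_error(self, model):
        # Mean squared error of model.predict on the held-out data
        preds = model.predict(self.X_val)
        return sum((p - y) ** 2 for p, y in zip(preds, self.y_val)) / len(self.y_val)

    def __call__(self, model):
        err = self.validation_error(model)
        if err < self.best_error - self.min_delta:
            self.best_error = err
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


class LinearModel:
    """Toy one-weight model y = w * x (illustration only)."""

    def __init__(self):
        self.w = 0.0

    def predict(self, X):
        return [self.w * x for x in X]


model = LinearModel()
monitor = ValidationMonitor([1.0, 2.0], [2.0, 4.0], patience=2, min_delta=1e-6)
for step in range(1000):
    model.w += 0.5 * (2.0 - model.w)  # toy update: move w halfway toward 2.0
    if monitor(model):
        break  # validation error has stopped improving meaningfully
```

Training stops long before the 1000-step budget, because the improvements shrink geometrically and soon fall below `min_delta`.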

@dansbecker What do you think?

dansbecker commented 8 years ago

@ilblackdragon Now that you bring it up, adding arguments to the fit() method is obviously the wrong choice, for the reason you mention.

In the related issue, I also considered creating the monitor you mentioned, or adding an argument to the TensorFlowEstimator constructor that specifies what fraction of the data is set aside for early stopping.

The monitor seems clearly more modular. Adding a frac_used_for_early_stopping argument at model specification seems to me to be an easier user interface.
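Under the frac_used_for_early_stopping alternative, the estimator would carve its own validation set out of the data passed to fit(). A minimal sketch of that split, assuming plain Python lists and a seeded shuffle; the helper name, its `seed` parameter, and its return order are hypothetical, and only the argument name comes from the comment above.

```python
import random


def split_for_early_stopping(X, y, frac_used_for_early_stopping=0.1, seed=42):
    """Hypothetical helper: shuffle (X, y) and hold out a fraction for early stopping.

    Returns (X_train, y_train, X_val, y_val). When len(X) >= 2, each side
    receives at least one example.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    if not 0.0 < frac_used_for_early_stopping < 1.0:
        raise ValueError("frac_used_for_early_stopping must be in (0, 1)")
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = max(1, round(len(X) * frac_used_for_early_stopping))
    n_val = min(n_val, len(X) - 1)        # always leave some training data
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return ([X[i] for i in train_idx], [y[i] for i in train_idx],
            [X[i] for i in val_idx], [y[i] for i in val_idx])


X = list(range(10))
y = [2 * x for x in X]
X_tr, y_tr, X_v, y_v = split_for_early_stopping(X, y, frac_used_for_early_stopping=0.2)
print(len(X_tr), len(X_v))  # → 8 2
```

The convenience comes at a cost the monitor approach avoids: the split policy (shuffling, seeding, size) gets baked into the estimator instead of being chosen by the user.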

If you confirm that you prefer the monitor, I'll do that.

ilblackdragon commented 8 years ago

Yeah, I saw your suggestion about a monitor, and after thinking it over and looking at scikit-learn, I think it's the better way to go. We can probably also remove a lot of options from the estimators if we move everything to monitors (like verbosity, savers, etc.); there are quite a few already, and adding more won't scale.

So yes to monitors! :) Maybe let's make a monitors folder, because I think they will become a core part of the API and there may be a lot of them.

Thanks for doing this, really appreciate your time and looking into options / trying things out!

dansbecker commented 8 years ago

Just committed the first version of the refactor. This appears to work with early stopping on validation data. I still need to

So it's definitely not ready yet... but here's an update showing the current status.

terrytangyuan commented 8 years ago

Looks great! BTW, you can reproduce the Travis error locally by running nosetests for the unit tests and pylint skflow for the lint check.

dansbecker commented 8 years ago

@ilblackdragon @terrytangyuan This is ready for CR.

ilblackdragon commented 8 years ago

@dansbecker In general this looks pretty good; I like how the trainer became a lot cleaner. The only remaining thing is to make the monitor less validation-centric (e.g., one could make a monitor that reports to an IPython console or into a database).
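As an example of a monitor that is not about validation at all, here is a sketch that records progress into a database instead of printing it. It uses Python's standard `sqlite3` module with an in-memory database; the `SQLiteMonitor` class and its `update(step, loss)`/`end()`/`history()` methods are hypothetical and only assume the monitor receives the step and loss after each training step.

```python
import sqlite3


class SQLiteMonitor:
    """Hypothetical monitor that records training progress in a SQLite table
    instead of printing it; it never requests early stopping."""

    def __init__(self, path=":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS progress (step INTEGER, loss REAL)")

    def update(self, step, loss):
        self.conn.execute("INSERT INTO progress VALUES (?, ?)", (step, loss))
        return False  # reporting only; never asks training to stop

    def end(self):
        self.conn.commit()  # persist rows when training finishes

    def history(self):
        return self.conn.execute(
            "SELECT step, loss FROM progress ORDER BY step").fetchall()


monitor = SQLiteMonitor()
for step, loss in enumerate([0.9, 0.7, 0.6]):  # stand-in for a real training loop
    monitor.update(step, loss)
monitor.end()
print(monitor.history())  # → [(0, 0.9), (1, 0.7), (2, 0.6)]
```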

dansbecker commented 8 years ago

I agree that this ended up validation-centric (in part because I originally intended this PR just to implement validation). Do you think it's reasonable to address that in a subsequent PR, when we start adding more classic monitoring-type functionality?

ilblackdragon commented 8 years ago

For this one, fix the lint errors, and I think it's OK to merge it in and keep iterating (I don't think people will jump on using this monitors API right away).

dansbecker commented 8 years ago

Fixed the lint. Let me know whether I should be squashing that type of commit.

I think this PR closes #85