tensorflow / adanet

Fast and flexible AutoML with learning guarantees.
https://adanet.readthedocs.io
Apache License 2.0

Early stopping 'best-practice' using Adanet #151

Open nicholasbreckwoldt opened 4 years ago

nicholasbreckwoldt commented 4 years ago

I would like to apply early stopping using Adanet (0.8.0), in line with this version’s release notes: “Support subnetwork hooks requesting early stopping”.

Based on my understanding of the ‘hooks’ argument of adanet.Estimator.train (i.e. a list of tf.train.SessionRunHook subclass instances used for callbacks inside the training loop), I believe the following configuration is the appropriate way to implement an early stopping callback with Adanet (note: I am using adanet.TPUEstimator):

import adanet
import tensorflow as tf

estimator = adanet.TPUEstimator(head=head, etc…)  # head and the remaining arguments are defined elsewhere

early_stopping_hook = tf.estimator.experimental.stop_if_no_decrease_hook(estimator=estimator, metric_name="loss", max_steps_without_decrease=max_steps_without_decrease)

estimator.train(input_fn=input_fn, hooks=[early_stopping_hook], max_steps=max_steps)
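For reference, this is how I understand the stopping rule itself, modelled in plain Python. It is a simplified sketch and not the TensorFlow source. The function name `should_stop` and the loss values are hypothetical, and the sketch assumes the rule sees only evaluation results keyed by global step. As far as I can tell, the real hook reads those results from the estimator's evaluation directory, so it would only have something to compare if evaluation runs alongside training (e.g. via tf.estimator.train_and_evaluate).

```python
from typing import Dict, Optional


def should_stop(eval_losses: Dict[int, float],
                max_steps_without_decrease: int,
                min_steps: int = 0) -> Optional[int]:
    """Return the global step at which early stopping would trigger, or None.

    Hypothetical, simplified model of a "stop if no decrease" rule;
    eval_losses maps global step -> evaluation loss.
    """
    best_loss = None
    best_step = None
    for step in sorted(eval_losses):
        if step < min_steps:
            continue  # evaluations before min_steps are ignored
        loss = eval_losses[step]
        if best_loss is None or loss < best_loss:
            best_loss, best_step = loss, step  # new best loss: restart the patience window
        if step - best_step >= max_steps_without_decrease:
            return step  # no decrease for max_steps_without_decrease steps
    return None  # never triggered (also the result when there are no evaluations)


# Hypothetical evaluation history: the loss stops improving after step 3000.
history = {1000: 0.90, 2000: 0.70, 3000: 0.62, 4000: 0.63, 5000: 0.64, 6000: 0.65}
print(should_stop(history, max_steps_without_decrease=2000))  # 5000
print(should_stop({}, max_steps_without_decrease=2000))       # None
```

Note the empty-history case: with no evaluation results the sketch never triggers, which is part of why I am unsure whether attaching the hook to estimator.train alone is enough.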

Is this the correct implementation? If so, what behaviour would result? Ideally, I would expect early stopping to be applied individually to each candidate subnetwork trained within a given Adanet iteration, so that no candidate overfits the data, but I am not sure whether this is the case.

I note that in #112 it was previously proposed that iteratively tuning ‘max_iteration_steps’ (i.e. epochs) is a commonly applied strategy for preventing overfitting and finding the best Adanet model (at the time, early stopping per iteration was not supported). However, this seems computationally expensive, since multiple Adanet models with different ‘max_iteration_steps’ values must be trained separately before they can be compared and the best model without excessive overfitting chosen. In addition, each candidate subnetwork in an Adanet iteration (whether it differs in parameterisation and/or architecture) likely has a different optimal number of epochs before overfitting sets in on a given dataset. So isn't applying the same number of training steps (i.e. ‘max_iteration_steps’) to every candidate within an iteration suboptimal?
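To put rough numbers on both concerns, here is a small sketch. It assumes, hypothetically, that every configuration in the grid runs for the same number of Adanet iterations and that each iteration costs exactly ‘max_iteration_steps’ training steps. The grid values, candidate names and "best step" figures are made up for illustration.

```python
from typing import Dict, List


def grid_search_steps(max_iteration_steps_grid: List[int], num_iterations: int) -> int:
    """Total training steps if each max_iteration_steps value is trained as a separate model."""
    return sum(m * num_iterations for m in max_iteration_steps_grid)


def budget_mismatch(best_steps: Dict[str, int], shared_steps: int) -> Dict[str, int]:
    """Shared step budget minus each candidate's (hypothetical) best step.

    Positive: the candidate keeps training past its best point (overfitting risk).
    Negative: the candidate is stopped before reaching its best point.
    """
    return {name: shared_steps - best for name, best in best_steps.items()}


print(grid_search_steps([1000, 2000, 4000], num_iterations=5))  # 35000
print(budget_mismatch({"small_dnn": 1500, "large_dnn": 3500}, shared_steps=2000))
# {'small_dnn': 500, 'large_dnn': -1500}
```

Under these assumptions, the three-value grid costs 35000 steps, whereas a single run with the largest value would cost 20000, and no single shared budget lands on both candidates' best points.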

With this in mind, will the above implementation of early stopping hooks with Adanet (assuming it is indeed correct) now handle this automatically? I.e. given a large, non-optimal ‘max_iteration_steps’ together with the early stopping callback, will Adanet automatically reduce the number of iteration steps for a given candidate subnetwork, according to the callback's criteria, to prevent overfitting in the current iteration, before moving on to the next iteration and repeating the process for a new set of candidate subnetworks?
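To make the question concrete, here is a plain-Python sketch of the behaviour I am hoping for. It is hypothetical and is not how Adanet is implemented. Each candidate is reduced to a made-up validation-loss curve, candidates are trained independently rather than jointly, the ensemble built in earlier iterations is ignored, and all names (train_candidate, run_iteration, search) are illustrative.

```python
from typing import Callable, Dict, List, Tuple

LossCurve = Callable[[int], float]  # hypothetical: training step -> validation loss


def train_candidate(curve: LossCurve, max_iteration_steps: int, eval_every: int,
                    max_steps_without_decrease: int) -> Tuple[int, float]:
    """Evaluate every eval_every steps; stop once the loss has not decreased for
    max_steps_without_decrease steps or the step budget is used up.
    Returns (step with the best loss, best loss)."""
    best_step, best_loss = 0, float("inf")
    for step in range(eval_every, max_iteration_steps + 1, eval_every):
        loss = curve(step)
        if loss < best_loss:
            best_step, best_loss = step, loss
        if step - best_step >= max_steps_without_decrease:
            break  # this candidate stops early; the others are unaffected
    return best_step, best_loss


def run_iteration(candidates: Dict[str, LossCurve], **kwargs) -> Tuple[str, int, float]:
    """Train each candidate with its own early stopping; keep the lowest best loss."""
    results = {name: train_candidate(curve, **kwargs) for name, curve in candidates.items()}
    winner = min(results, key=lambda name: results[name][1])
    return (winner, *results[winner])


def search(iterations: List[Dict[str, LossCurve]], **kwargs) -> List[Tuple[str, int, float]]:
    """Repeat run_iteration for each iteration's (hypothetical) set of candidates."""
    return [run_iteration(candidates, **kwargs) for candidates in iterations]


# Made-up U-shaped curves: each candidate has a different best step.
candidates = {
    "small": lambda s: (s - 1500) ** 2 / 1e7 + 0.5,
    "large": lambda s: (s - 3500) ** 2 / 1e7 + 0.4,
}
settings = dict(max_iteration_steps=10000, eval_every=500, max_steps_without_decrease=1500)
print(run_iteration(candidates, **settings))  # ('large', 3500, 0.4)
```

Here "small" stops at step 3000 and "large" at step 5000, both well short of the 10000-step budget, which is the per-candidate behaviour I am asking about.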

In essence, I want to understand what is considered ‘best practice’ for addressing overfitting with Adanet and for finding the best Adanet model.

Thanks in advance.