tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

disable early-stopping does not work #28

Closed: Howard-ll closed this issue 3 years ago.

Howard-ll commented 3 years ago

https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#disabling-the-validation-dataset-for-gbt

I tried to disable both early stopping and the validation dataset as described there, but it does not seem to work.

Model creation without early stopping or a validation dataset:

    import tensorflow_decision_forests as tfdf

    # n_trees, depth, learning_rate, tmp_dir_name, x_selected and y are defined elsewhere.
    model = tfdf.keras.GradientBoostedTreesModel(
        num_trees=n_trees,
        growing_strategy="BEST_FIRST_GLOBAL",
        max_depth=depth,
        min_examples=1,
        shrinkage=learning_rate,
        categorical_algorithm="RANDOM",
        use_hessian_gain=True,
        validation_ratio=0.0,
        early_stopping=None,
        temp_directory=tmp_dir_name,
    )
    model.fit(x=x_selected, y=y)

Error message:

    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow_decision_forests/keras/core.py", line 780, in fit
        history = super(CoreModel, self).fit(
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1229, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py", line 435, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow_decision_forests/keras/core.py", line 994, in on_epoch_end
        self._model._train_model()  # pylint:disable=protected-access
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow_decision_forests/keras/core.py", line 915, in _train_model
        tf_core.train(
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow_decision_forests/tensorflow/core.py", line 494, in train
        return training_op.SimpleMLModelTrainer(
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow/python/util/tf_export.py", line 404, in wrapper
        return f(**kwargs)
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow_decision_forests/tensorflow/ops/training/op.py", line 512, in simple_ml_model_trainer
        _ops.raise_from_not_ok_status(e, name)
    File "/opt/conda/envs/tf2.5.0/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 6897, in raise_from_not_ok_status
        six.raise_from(core._status_to_exception(e.code, message), None)
    File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.UnknownError: TensorFlow: INVALID_ARGUMENT: Early stopping requires a validation set. Either set "validation_set_ratio" to be greater than 0, or disable early stopping. [Op:SimpleMLModelTrainer]

achoum commented 3 years ago

Thanks for the report.

TL;DR

Can you try replacing early_stopping=None with early_stopping="NONE"?

Details

early_stopping is a categorical (string) parameter (documentation) that supports one of three possible values: NONE, MIN_LOSS_FINAL and LOSS_INCREASE. Setting early_stopping=None (or setting any other parameter to None) leaves the parameter unspecified. In that case, the library falls back to its default logic, which here is LOSS_INCREASE. Early stopping therefore stays enabled, and with validation_ratio=0.0 there is no validation set for it to use, hence the error.
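To make the None versus "NONE" distinction concrete, here is a minimal pure-Python sketch of the resolution logic described above. It does not call TF-DF: resolve_early_stopping and check_config are hypothetical helpers that assume the behavior stated in this reply (Python None means "unspecified" and falls back to the LOSS_INCREASE default), plus the validation-set check implied by the error message in the report.

```python
from typing import Optional

# HYPOTHETICAL sketch of the behavior described in the reply above.
# It is NOT TF-DF's actual code; names here are illustrative only.
EARLY_STOPPING_VALUES = ("NONE", "MIN_LOSS_FINAL", "LOSS_INCREASE")
EARLY_STOPPING_DEFAULT = "LOSS_INCREASE"  # default stated in the reply


def resolve_early_stopping(value: Optional[str]) -> str:
    """Return the effective early-stopping policy for a user-supplied value."""
    if value is None:
        # Python None means "not set": fall back to the library default.
        return EARLY_STOPPING_DEFAULT
    if value not in EARLY_STOPPING_VALUES:
        raise ValueError(
            f"early_stopping must be one of {EARLY_STOPPING_VALUES}, got {value!r}"
        )
    return value


def check_config(early_stopping: Optional[str], validation_ratio: float) -> None:
    """Mimic the consistency check behind the INVALID_ARGUMENT error."""
    policy = resolve_early_stopping(early_stopping)
    if policy != "NONE" and validation_ratio <= 0.0:
        raise ValueError(
            "Early stopping requires a validation set. Either set "
            '"validation_set_ratio" to be greater than 0, or disable early stopping.'
        )


# early_stopping=None still resolves to the default policy ...
print(resolve_early_stopping(None))    # LOSS_INCREASE
# ... while the string "NONE" actually disables early stopping.
print(resolve_early_stopping("NONE"))  # NONE
check_config("NONE", 0.0)  # passes: no validation set is needed
```

With the real library, the corresponding fix is the one-line change from the TL;DR: keep validation_ratio=0.0 and pass early_stopping="NONE" instead of early_stopping=None.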

I'll make the documentation more explicit and add a warning when this case is detected :).
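A warning of the kind mentioned above could look roughly like the following sketch. It is hypothetical and not TF-DF's implementation: the STRING_VALUED_PARAMS table and the warn_on_none function are illustrative names, and only Python's standard warnings module is used.

```python
import warnings
from typing import Dict, Tuple

# HYPOTHETICAL table of string-valued hyperparameters that accept "NONE".
# Illustrative only; this is not TF-DF's actual implementation.
STRING_VALUED_PARAMS: Dict[str, Tuple[str, ...]] = {
    "early_stopping": ("NONE", "MIN_LOSS_FINAL", "LOSS_INCREASE"),
}


def warn_on_none(**hyperparameters: object) -> None:
    """Warn when Python None is passed where the string "NONE" was likely meant."""
    for name, value in hyperparameters.items():
        allowed = STRING_VALUED_PARAMS.get(name, ())
        if value is None and "NONE" in allowed:
            warnings.warn(
                f"{name}=None leaves the parameter unspecified, so the library "
                f'default is used. Did you mean {name}="NONE"?',
                UserWarning,
                stacklevel=2,
            )


if __name__ == "__main__":
    # Emits one UserWarning; early_stopping="NONE" would emit none.
    warn_on_none(early_stopping=None, validation_ratio=0.0)
```

Parameters not listed in the table (such as validation_ratio here) are ignored, so the check only fires for the ambiguous None-versus-"NONE" case.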

Howard-ll commented 3 years ago

Thanks a lot. I did not know it was a string input.