tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

Automated hyper-parameter tuning broken for Regression #116

Closed: rishiraj closed this issue 2 years ago

rishiraj commented 2 years ago

Hyper-parameter tuning by specifying the tuner constructor argument of the model currently only works for classification tasks. For regression tasks, where the default loss function is MSE or RMSE, it breaks. At the end of the tuning trials there is a hessian optimization phase, but Hessian learning is disabled for GBT regression with the MSE loss. Hence the following error is shown:

---------------------------------------------------------------------------
UnknownError                              Traceback (most recent call last)
/tmp/ipykernel_19/3508847010.py in <module>
      1 # Tune the model. Notice the `tuner=tuner`.
      2 gb_model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner, task=tfdf.keras.Task.REGRESSION)
----> 3 gb_model.fit(x=train_tfds, verbose=2)

/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py in fit(self, x, y, callbacks, verbose, validation_steps, validation_data, sample_weight, steps_per_epoch, class_weight, **kwargs)
   1506         sample_weight=sample_weight,
   1507         steps_per_epoch=steps_per_epoch,
-> 1508         class_weight=class_weight)
   1509 
   1510   @base_tracking.no_automatic_dependency_tracking

/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py in _fit_implementation(self, x, y, verbose, callbacks, sample_weight, validation_data, validation_steps, steps_per_epoch, class_weight)
   1821       tf_logging.info("Training model...")
   1822 
-> 1823     self._train_model(cluster_coordinator=coordinator)
   1824 
   1825     if self._verbose >= 1:

/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py in _train_model(self, cluster_coordinator)
   2301             deployment_config=deployment_config,
   2302             try_resume_training=self._try_resume_training,
-> 2303             has_validation_dataset=self._has_validation_dataset)
   2304 
   2305       else:

/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/tensorflow/core.py in train(input_ids, label_id, weight_id, model_id, learner, task, generic_hparms, ranking_group, uplift_treatment, training_config, deployment_config, guide, model_dir, keep_model_in_resource, try_resume_training, has_validation_dataset)
    866       guide=guide.SerializeToString(),
    867       has_validation_dataset=has_validation_dataset,
--> 868       use_file_prefix=use_file_prefix)
    869 
    870 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/tf_export.py in wrapper(*args, **kwargs)
    398           'Please pass these args as kwargs instead.'
    399           .format(f=f.__name__, kwargs=f_argspec.args))
--> 400     return f(**kwargs)
    401 
    402   return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)

<string> in simple_ml_model_trainer(feature_ids, label_id, weight_id, model_id, model_dir, learner, hparams, task, training_config, deployment_config, guide, has_validation_dataset, use_file_prefix, name)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7162 def raise_from_not_ok_status(e, name):
   7163   e.message += (" name: " + name if name is not None else "")
-> 7164   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7165 
   7166 

UnknownError: TensorFlow: INVALID_ARGUMENT: Loss does not support hessian optimization [Op:SimpleMLModelTrainer]
rstz commented 2 years ago

Hi, thank you for flagging this issue. Could you please tell us which arguments you're using for the hyperparameter tuner or, even better, come up with a repro in the form of a colab?

rishiraj commented 2 years ago

The exact same arguments and code as shown in the official tutorial colab. The only difference is that I used it for regression instead of classification, with task=tfdf.keras.Task.REGRESSION:

import tensorflow_decision_forests as tfdf

tuner = tfdf.tuner.RandomSearch(num_trials=50)
tuner.choice("min_examples", [2, 5, 7, 10])
tuner.choice("categorical_algorithm", ["CART", "RANDOM"])
local_search_space = tuner.choice("growing_strategy", ["LOCAL"])
local_search_space.choice("max_depth", [3, 4, 5, 6, 8])
global_search_space = tuner.choice("growing_strategy", ["BEST_FIRST_GLOBAL"], merge=True)
global_search_space.choice("max_num_nodes", [16, 32, 64, 128, 256])

# The argument that I think is throwing the error
tuner.choice("use_hessian_gain", [True, False])

tuner.choice("shrinkage", [0.02, 0.05, 0.10, 0.15])
tuner.choice("num_candidate_attributes_ratio", [0.2, 0.5, 0.9, 1.0])
rishiraj commented 2 years ago

I can confirm that after commenting out tuner.choice("use_hessian_gain", [True, False]), the error no longer occurs.
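
For reference, a sketch of the working regression setup is simply the same search space with that single choice removed:

tuner = tfdf.tuner.RandomSearch(num_trials=50)
tuner.choice("min_examples", [2, 5, 7, 10])
# ... all other choices as above, without "use_hessian_gain" ...
gb_model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner, task=tfdf.keras.Task.REGRESSION)
gb_model.fit(x=train_tfds, verbose=2)  # completes without the hessian error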

rstz commented 2 years ago

Great! Indeed, tuner.choice("use_hessian_gain", [True, False]) explicitly instructs the tuner to try both True and False for the use_hessian_gain setting, even though Hessian gain is unavailable for the MSE loss.
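
One possible workaround (a sketch, not verified here) is to pin the parameter to False for regression with the MSE loss, so no trial ever requests hessian optimization:

# Sketch: fix use_hessian_gain to a single supported value instead of searching over it.
tuner.choice("use_hessian_gain", [False])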