jrbourbeau / dask-optuna

Scale Optuna with Dask
https://jrbourbeau.github.io/dask-optuna/
MIT License

Feature request: Implement XGBoostPruningCallback #21

Closed · daria-dc closed this issue 3 years ago

daria-dc commented 3 years ago

Hi,

thanks for the package, I find it really useful! I was using dask-optuna to tune the hyperparameters of an XGBoost model on multiple GPUs, and found that when I add the line

```python
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-mae")
```

the code stops working and throws serialization errors.

It seems to me that this Optuna feature does not integrate with dask-optuna. It would be nice to have support for it in the future!

jrbourbeau commented 3 years ago

Thanks for raising an issue @daria-dc, glad to hear you've found dask-optuna useful!

Are you able to post a minimal reproducible example (see https://blog.dask.org/2018/02/28/minimal-bug-reports)? I attempted to come up with one by modifying this XGBoostPruningCallback example from Optuna to use `dask_optuna.DaskStorage` and `"validation-mae"` in `XGBoostPruningCallback` (see the snippet below); however, this example runs successfully without throwing any serialization errors. How does it differ from your use case?

Example:

```python
"""
Optuna example that demonstrates a pruner for XGBoost.

In this example, we optimize the validation accuracy of cancer detection using XGBoost.
We optimize both the choice of booster model and their hyperparameters. Throughout
training of models, a pruner observes intermediate results and stop unpromising trials.

You can run this example as follows:
    $ python xgboost_integration.py
"""
import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split
import xgboost as xgb

import optuna
import dask_optuna
import dask.distributed


# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    data, target = sklearn.datasets.load_boston(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
    dtrain = xgb.DMatrix(train_x, label=train_y)
    dvalid = xgb.DMatrix(valid_x, label=valid_y)

    param = {
        "silent": 1,
        "objective": "reg:squarederror",
        "eval_metric": "mae",
        "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
        "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
        "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
    }

    if param["booster"] == "gbtree" or param["booster"] == "dart":
        param["max_depth"] = trial.suggest_int("max_depth", 1, 9)
        param["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True)
        param["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True)
        param["grow_policy"] = trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"])
    if param["booster"] == "dart":
        param["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"])
        param["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"])
        param["rate_drop"] = trial.suggest_float("rate_drop", 1e-8, 1.0, log=True)
        param["skip_drop"] = trial.suggest_float("skip_drop", 1e-8, 1.0, log=True)

    # Add a callback for pruning.
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-mae")
    bst = xgb.train(param, dtrain, evals=[(dvalid, "validation")], callbacks=[pruning_callback])
    preds = bst.predict(dvalid)
    pred_labels = np.rint(preds)
    score = sklearn.metrics.mean_squared_error(valid_y, pred_labels)
    return score


if __name__ == "__main__":
    # Create a local dask cluster with 4 workers
    with dask.distributed.Client(n_workers=4):
        study = optuna.create_study(
            storage=dask_optuna.DaskStorage(),
            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
            direction="minimize",
        )
        study.optimize(objective, n_trials=5)
        print(study.best_trial)
```

daria-dc commented 3 years ago

Thanks for the quick answer @jrbourbeau !

My minimal example would look like this:


```python
import pandas as pd
import numpy as np
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from joblib import parallel_backend
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import xgboost as xgb
import optuna
import dask_optuna

cluster = LocalCUDACluster(memory_limit=0)
client = Client(cluster)
workers = client.has_what().keys()
n_workers = len(workers)

X = pd.DataFrame(np.random.rand(10*10).reshape(10,10))
y = pd.Series(np.random.rand(10))

def objective(trial, X, y):

    X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.2)
    X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.2)

    dtrain = xgb.DMatrix(X_train, y_train)
    dval = xgb.DMatrix(X_val, y_val)
    dtest = xgb.DMatrix(X_test, y_test)

    params = {
              'objective':'reg:squarederror',
              'eval_metric': 'mae',              
              'booster': 'gbtree',
              'tree_method': 'gpu_hist'
              }

    num_boost_round = trial.suggest_int("num_boost_round", 10, 50, 10)

    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-mae")

    trained_model = xgb.train(
                     params,
                     dtrain=dtrain,
                     evals=[(dval, 'validation')],
                     early_stopping_rounds=10,
                     num_boost_round = num_boost_round,
                     callbacks=[pruning_callback]
                     )

    prediction = trained_model.predict(dtest)
    score = mean_absolute_error(y_test, prediction)
    return score

storage = dask_optuna.DaskStorage('sqlite:///example.db')
study = optuna.create_study(sampler = optuna.samplers.TPESampler(),
                                study_name = 'example',
                                direction = "minimize",
                                storage = storage,
                                load_if_exists=True)

with parallel_backend("dask"):
        study.optimize(lambda trial: objective(trial, X=X, y=y),
                           n_trials=10,
                           n_jobs=n_workers)
```

And the error I receive:

```
TypeError: ('Could not serialize object of type tuple.', "(<function apply at 0x7fbcdca1a4c0>, batch_of__optimize_sequential_1_calls, [], {'tasks': [(<function _optimize_sequential at 0x7fbdedec08b0>, [<optuna.study.Study object at 0x7fbd632d52b0>, <function <lambda> at 0x7fbcdc42f430>, 1, None, (), None, False], {'reseed_sampler_rng': True, 'time_start': datetime.datetime(2020, 11, 24, 8, 43, 12, 883501), 'progress_bar': None})]})")
distributed.comm.utils - ERROR - ('Could not serialize object of type tuple.', "(<function apply at 0x7fbcdca1a4c0>, batch_of__optimize_sequential_1_calls, [], {'tasks': [(<function _optimize_sequential at 0x7fbdedec08b0>, [<optuna.study.Study object at 0x7fbd632d52b0>, <function <lambda> at 0x7fbcdc42f430>, 1, None, (), None, False], {'reseed_sampler_rng': True, 'time_start': datetime.datetime(2020, 11, 24, 8, 43, 12, 883501), 'progress_bar': None})]})")
Traceback (most recent call last):
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/distributed/comm/utils.py", line 34, in _to_frames
    protocol.dumps(
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/distributed/protocol/core.py", line 50, in dumps
    data = {
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/distributed/protocol/core.py", line 51, in <dictcomp>
    key: serialize(
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/distributed/protocol/serialize.py", line 277, in serialize
    raise TypeError(msg, str(x)[:10000])
```

This code runs fine if I leave out the pruning callback. Can you reproduce this?

jrbourbeau commented 3 years ago

Ah gotcha, thanks for providing a nice example @daria-dc! In quickly putting together my initial example I forgot to select the joblib Dask backend. When I add `with joblib.parallel_backend("dask"):` I'm able to reproduce the serialization errors you're seeing. I'll set aside some time later to look further into this issue.
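
For reference, here's roughly how the `__main__` block from my earlier snippet changes to select the joblib Dask backend (a minimal sketch that reuses the `objective` function defined in the example above):

```python
import joblib
import dask.distributed

import optuna
import dask_optuna


if __name__ == "__main__":
    # Create a local dask cluster with 4 workers
    with dask.distributed.Client(n_workers=4):
        study = optuna.create_study(
            storage=dask_optuna.DaskStorage(),
            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
            direction="minimize",
        )
        # Routing optimization through the joblib Dask backend is what causes
        # the objective function to be pickled and shipped to the workers.
        with joblib.parallel_backend("dask"):
            study.optimize(objective, n_trials=5, n_jobs=-1)
        print(study.best_trial)
```

With that change I see the same `TypeError: ('Could not serialize object of type tuple.', ...)` error.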

jrbourbeau commented 3 years ago

Looking at the tuple that couldn't be serialized, it turns out the issue is that cloudpickle cannot serialize the objective function when `optuna.integration.XGBoostPruningCallback` is used inside of it. The underlying serialization issue isn't related to `XGBoostPruningCallback` specifically; rather, it has to do with the `_IntegrationModule` class Optuna uses to support lazy importing of modules:

```
In [1]: import cloudpickle

In [2]: import optuna

In [3]: def test_func(trial):
   ...:     x = optuna.integration.XGBoostPruningCallback(trial, "validation-mae")
   ...:

In [4]: cloudpickle.dumps(test_func)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-460175ced07d> in <module>
----> 1 cloudpickle.dumps(test_func)

~/miniforge3/envs/dask-optuna/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
     71                 file, protocol=protocol, buffer_callback=buffer_callback
     72             )
---> 73             cp.dump(obj)
     74             return file.getvalue()
     75

~/miniforge3/envs/dask-optuna/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    561     def dump(self, obj):
    562         try:
--> 563             return Pickler.dump(self, obj)
    564         except RuntimeError as e:
    565             if "recursion" in e.args[0]:

TypeError: cannot pickle '_IntegrationModule' object
```

As a workaround, you can bypass lazy importing in the objective function by importing `XGBoostPruningCallback` earlier in your script. So instead of

```python
import optuna

def objective(trial, ...):
    ...
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-mae")
    ...
```

do

```python
from optuna.integration import XGBoostPruningCallback

def objective(trial, ...):
    ...
    pruning_callback = XGBoostPruningCallback(trial, "validation-mae")
    ...
```

and that should fix the serialization issue.
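
As a quick sanity check along the lines of the IPython session above (just a sketch, independent of dask-optuna itself), cloudpickle handles the function fine once the callback class is imported directly:

```python
import cloudpickle
from optuna.integration import XGBoostPruningCallback


def test_func(trial):
    # The function now references the class directly, so cloudpickle never
    # tries to pickle the lazily-imported optuna.integration
    # (_IntegrationModule) object.
    return XGBoostPruningCallback(trial, "validation-mae")


# Previously this raised "TypeError: cannot pickle '_IntegrationModule' object";
# with the direct import it round-trips without error.
restored = cloudpickle.loads(cloudpickle.dumps(test_func))
print(restored)
```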

jrbourbeau commented 3 years ago

It appears this has already been reported upstream in cloudpickle (xref https://github.com/cloudpipe/cloudpickle/issues/397)

daria-dc commented 3 years ago

Thanks @jrbourbeau, your solution is working for me!

Going to close this issue.