jrbourbeau / dask-optuna

Scale Optuna with Dask
https://jrbourbeau.github.io/dask-optuna/
MIT License

Support Optuna + Joblib #10

Closed jrbourbeau closed 3 years ago

jrbourbeau commented 3 years ago

This PR updates DaskStorage to work directly with Optuna's Joblib internals, so we can now use optuna.Study.optimize instead of dask_optuna.optimize:

import optuna
import joblib
from dask.distributed import Client

import dask_optuna

def objective(trial):
    # Minimize (x - 2)^2 over x in [-10, 10]
    x = trial.suggest_uniform("x", -10, 10)
    return (x - 2) ** 2

if __name__ == "__main__":

    with Client(processes=True) as client:
        # Wrap an Optuna SQLite storage so trial results can be recorded
        # from workers across the Dask cluster
        dask_storage = dask_optuna.DaskStorage("sqlite:///example.db")
        study = optuna.create_study(storage=dask_storage)
        # Run trials in parallel on the cluster via the Dask Joblib backend
        with joblib.parallel_backend("dask"):
            study.optimize(objective, n_trials=300, n_jobs=-1)

Note that the DaskStorage.__reduce__ implementation here makes DaskStorage picklable between nodes in a Dask cluster, but we don't include the underlying Optuna storage class when pickling a DaskStorage instance.
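To illustrate the idea, here's a minimal self-contained sketch of that __reduce__ pattern. This is not the actual dask_optuna implementation; the class and attribute names (DaskStorageSketch, storage_url, _base_storage) are assumptions made for the example:

```python
import pickle

class DaskStorageSketch:
    """Hypothetical sketch of the DaskStorage.__reduce__ approach."""

    def __init__(self, storage_url):
        self.storage_url = storage_url
        # In the real class this would hold the (potentially large)
        # underlying Optuna storage; it is intentionally NOT pickled.
        self._base_storage = object()

    def __reduce__(self):
        # Serialize only the lightweight constructor arguments; each
        # worker reconstructs its storage connection locally instead of
        # receiving the full storage object over the wire.
        return (DaskStorageSketch, (self.storage_url,))

storage = DaskStorageSketch("sqlite:///example.db")
roundtripped = pickle.loads(pickle.dumps(storage))
```

Because __reduce__ returns only the constructor and its arguments, the pickled payload stays small no matter how large the wrapped storage grows.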

The benefit of this approach is that we avoid serializing and deserializing a potentially large storage class whenever a DaskStorage instance is unexpectedly pickled. The downside is that a DaskStorage can't be pickled to disk for later use. To mitigate this, we've also added a DaskStorage.get_base_storage() method, which returns the underlying Optuna storage class; that object can be pickled and saved to a file.
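The persistence workflow can be sketched as follows. This is a stand-in example, not dask_optuna itself: BaseStorageSketch and DaskStorageSketch are hypothetical classes mimicking the proxy/base split, with only get_base_storage() taken from this PR:

```python
import os
import pickle
import tempfile

class BaseStorageSketch:
    """Stand-in for an underlying Optuna storage holding trial results."""
    def __init__(self):
        self.trials = [{"params": {"x": 1.5}, "value": 0.25}]

class DaskStorageSketch:
    """Hypothetical proxy mirroring the DaskStorage pattern: cheap to
    pickle across the cluster, but its pickled form omits the base storage."""
    def __init__(self):
        self._base = BaseStorageSketch()

    def get_base_storage(self):
        # Return the underlying storage, which CAN be pickled to disk
        return self._base

proxy = DaskStorageSketch()
path = os.path.join(tempfile.mkdtemp(), "storage.pkl")

# Persist the base storage (not the proxy) so results survive the cluster
with open(path, "wb") as f:
    pickle.dump(proxy.get_base_storage(), f)

with open(path, "rb") as f:
    restored = pickle.load(f)
```

So when you want the study's results to outlive the Dask cluster, you dump the object returned by get_base_storage() rather than the DaskStorage wrapper itself.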