jrbourbeau / dask-optuna

Scale Optuna with Dask
https://jrbourbeau.github.io/dask-optuna/
MIT License

Add dask_optuna.optimize function #4

Closed jrbourbeau closed 3 years ago

jrbourbeau commented 3 years ago

This wraps batch-submitting trials to a Dask cluster in a utility function, dask_optuna.optimize, with a signature similar to optuna.Study.optimize:

```python
import optuna
import dask.distributed
import dask_optuna

def objective(trial):
    x = trial.suggest_uniform("x", -10, 10)
    return (x - 2) ** 2

with dask.distributed.Client() as client:
    # Create a study using Dask-compatible storage
    study = optuna.create_study(storage=dask_optuna.DaskStorage())
    # Optimize in parallel on your Dask cluster
    dask_optuna.optimize(study, objective, n_trials=100)
    print(f"best_params = {study.best_params}")
```

NOTE: Currently this only works with synchronous Clients, but it's a start.

Longer-term, since Optuna uses joblib under the hood, one could run optuna.Study.optimize inside a joblib.parallel_backend("dask") context manager to farm trials out to a Dask cluster. However, today this results in non-pickleable objects being wrapped in joblib.delayed, so the joblib approach doesn't work across multiple processes.
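
The serialization constraint behind this can be illustrated with the standard library alone: joblib's multi-process (and Dask) backends must pickle each delayed task before shipping it to a worker, so anything stdlib pickle rejects cannot cross a process boundary. A minimal sketch, using a hypothetical `is_picklable` helper (this shows the general constraint, not the specific Optuna internals that fail here):

```python
import pickle

def is_picklable(obj):
    """Return True if stdlib pickle can serialize `obj`."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

# Named, importable callables pickle by reference and can be sent to
# worker processes; lambdas and other locally defined objects cannot,
# because pickle fails to look them up by qualified name on the worker.
print(is_picklable(len))              # builtin, picklable
print(is_picklable(lambda trial: 0))  # lambda, not picklable
```

This is why distributed backends (Dask included) typically require objectives and their captured state to be importable, module-level objects, or to use a richer serializer such as cloudpickle.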