jrbourbeau / dask-optuna

Scale Optuna with Dask
https://jrbourbeau.github.io/dask-optuna/
MIT License
35 stars 6 forks source link

Dask distributed: ModuleNotFoundError: No module named 'dask_optuna' #24

Open gcaria opened 2 years ago

gcaria commented 2 years ago

My naive interpretation of this issue is that the dask_optuna package can't be sent to the workers — please let me know if I'm doing anything wrong here.


from dask_gateway import Gateway

# Connect to the deployment's Dask Gateway server (address/auth come from
# the local gateway configuration, not shown here).
gateway = Gateway()
# Fetch the server-defined cluster options (worker size, image, etc.).
options = gateway.cluster_options()
# Launch a new remote cluster with those options; workers run in a
# separate environment from this client process.
# NOTE(review): the ModuleNotFoundError below suggests that remote
# environment lacks dask-optuna — confirm the scheduler/worker image
# has it installed.
cluster = gateway.new_cluster(cluster_options=options)
# Client connected to the new cluster; joblib's "dask" backend will pick
# up this client implicitly.
client = cluster.get_client()
def objective(trial):
    """Optuna objective: mean 5-fold CV score for an XGBoost regressor.

    The score is negated RMSE, so larger (less negative) is better —
    consistent with a study created with direction="maximize".
    """
    # Search the number of boosting rounds in [100, 500] in steps of 5.
    n_estimators = trial.suggest_int('n_estimators', 100, 500, step=5)

    regressor = XGBRegressor(eval_metric="rmse", n_estimators=n_estimators)

    # Fixed-seed shuffled folds so every trial is scored on the same splits.
    splitter = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_scores = cross_val_score(
        regressor,
        X_train,
        y_train,
        cv=splitter,
        scoring="neg_root_mean_squared_error",
    )

    return fold_scores.mean()
# DaskStorage keeps Optuna trial state on the Dask scheduler; constructing
# it pickles a registration function and runs it on the scheduler, which is
# why the scheduler environment must also have dask-optuna installed
# (this is the line that raises ModuleNotFoundError in the traceback).
storage = dask_optuna.DaskStorage()

# Maximize because the objective returns negated RMSE (higher is better).
study = optuna.create_study(
    direction="maximize",
    storage=storage,
)

# Route joblib's n_jobs=-1 parallelism onto the Dask cluster so the 200
# trials run concurrently on the workers.
with joblib.parallel_backend("dask"):
    study.optimize(objective, n_trials=200, n_jobs=-1)
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
/tmp/ipykernel_465/3864023500.py in <module>
     23 
     24 # Create an Optuna study using a Dask-compatible Optuna storage class
---> 25 storage = dask_optuna.DaskStorage()
     26 
     27 study = optuna.create_study(

/srv/conda/envs/notebook/lib/python3.9/site-packages/dask_optuna/storage.py in __init__(self, storage, name, client)
    323             self._started = asyncio.ensure_future(_register())
    324         else:
--> 325             self.client.run_on_scheduler(
    326                 register_with_scheduler, storage=storage, name=self.name
    327             )

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in run_on_scheduler(self, function, *args, **kwargs)
   2419         Client.start_ipython_scheduler : Start an IPython session on scheduler
   2420         """
-> 2421         return self.sync(self._run_on_scheduler, function, *args, **kwargs)
   2422 
   2423     async def _run(

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    866             return future
    867         else:
--> 868             return sync(
    869                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    870             )

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    330     if error[0]:
    331         typ, exc, tb = error[0]
--> 332         raise exc.with_traceback(tb)
    333     else:
    334         return result[0]

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/utils.py in f()
    313             if callback_timeout is not None:
    314                 future = asyncio.wait_for(future, callback_timeout)
--> 315             result[0] = yield future
    316         except Exception:
    317             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.9/site-packages/tornado/gen.py in run(self)
    760 
    761                     try:
--> 762                         value = future.result()
    763                     except Exception:
    764                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in _run_on_scheduler(self, function, wait, *args, **kwargs)
   2378 
   2379     async def _run_on_scheduler(self, function, *args, wait=True, **kwargs):
-> 2380         response = await self.scheduler.run_function(
   2381             function=dumps(function, protocol=4),
   2382             args=dumps(args, protocol=4),

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    893             name, comm.name = comm.name, "ConnectionPool." + key
    894             try:
--> 895                 result = await send_recv(comm=comm, op=key, **kwargs)
    896             finally:
    897                 self.pool.reuse(self.addr, comm)

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    686         if comm.deserialize:
    687             typ, exc, tb = clean_exception(**response)
--> 688             raise exc.with_traceback(tb)
    689         else:
    690             raise Exception(response["exception_text"])

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in handle_comm()
    528                             result = asyncio.ensure_future(result)
    529                             self._ongoing_coroutines.add(result)
--> 530                             result = await result
    531                     except (CommClosedError, CancelledError):
    532                         if self.status in (Status.running, Status.paused):

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/worker.py in run()
   4528 async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=True):
   4529     kwargs = kwargs or {}
-> 4530     function = pickle.loads(function)
   4531     if is_coro is None:
   4532         is_coro = iscoroutinefunction(function)

/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/protocol/pickle.py in loads()
     73             return pickle.loads(x, buffers=buffers)
     74         else:
---> 75             return pickle.loads(x)
     76     except Exception:
     77         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

ModuleNotFoundError: No module named 'dask_optuna'