My naive interpretation of this issue is that the dask_optuna package can't be sent to the scheduler/workers because it isn't installed in their environments. Please let me know if I'm doing anything wrong here; a quick sanity check of that interpretation is sketched after the traceback below.
import dask_optuna
import joblib
import optuna
from dask_gateway import Gateway
from sklearn.model_selection import KFold, cross_val_score
from xgboost import XGBRegressor

# Connect to a Dask Gateway cluster
gateway = Gateway()
options = gateway.cluster_options()
cluster = gateway.new_cluster(cluster_options=options)
client = cluster.get_client()

# Objective: 5-fold CV RMSE of an XGBoost regressor
# (X_train / y_train come from earlier in the notebook)
def objective(trial):
    n_estimators = trial.suggest_int("n_estimators", 100, 500, step=5)
    model = XGBRegressor(eval_metric="rmse", n_estimators=n_estimators)
    scores = cross_val_score(
        model,
        X_train,
        y_train,
        cv=KFold(n_splits=5, shuffle=True, random_state=42),
        scoring="neg_root_mean_squared_error",
    )
    return scores.mean()

# Create an Optuna study using a Dask-compatible Optuna storage class
storage = dask_optuna.DaskStorage()

study = optuna.create_study(
    direction="maximize",
    storage=storage,
)

with joblib.parallel_backend("dask"):
    study.optimize(objective, n_trials=200, n_jobs=-1)

The dask_optuna.DaskStorage() call fails with:
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
/tmp/ipykernel_465/3864023500.py in <module>
23
24 # Create an Optuna study using a Dask-compatible Optuna storage class
---> 25 storage = dask_optuna.DaskStorage()
26
27 study = optuna.create_study(
/srv/conda/envs/notebook/lib/python3.9/site-packages/dask_optuna/storage.py in __init__(self, storage, name, client)
323 self._started = asyncio.ensure_future(_register())
324 else:
--> 325 self.client.run_on_scheduler(
326 register_with_scheduler, storage=storage, name=self.name
327 )
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in run_on_scheduler(self, function, *args, **kwargs)
2419 Client.start_ipython_scheduler : Start an IPython session on scheduler
2420 """
-> 2421 return self.sync(self._run_on_scheduler, function, *args, **kwargs)
2422
2423 async def _run(
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
866 return future
867 else:
--> 868 return sync(
869 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
870 )
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
330 if error[0]:
331 typ, exc, tb = error[0]
--> 332 raise exc.with_traceback(tb)
333 else:
334 return result[0]
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/utils.py in f()
313 if callback_timeout is not None:
314 future = asyncio.wait_for(future, callback_timeout)
--> 315 result[0] = yield future
316 except Exception:
317 error[0] = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.9/site-packages/tornado/gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py in _run_on_scheduler(self, function, wait, *args, **kwargs)
2378
2379 async def _run_on_scheduler(self, function, *args, wait=True, **kwargs):
-> 2380 response = await self.scheduler.run_function(
2381 function=dumps(function, protocol=4),
2382 args=dumps(args, protocol=4),
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
893 name, comm.name = comm.name, "ConnectionPool." + key
894 try:
--> 895 result = await send_recv(comm=comm, op=key, **kwargs)
896 finally:
897 self.pool.reuse(self.addr, comm)
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
686 if comm.deserialize:
687 typ, exc, tb = clean_exception(**response)
--> 688 raise exc.with_traceback(tb)
689 else:
690 raise Exception(response["exception_text"])
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/core.py in handle_comm()
528 result = asyncio.ensure_future(result)
529 self._ongoing_coroutines.add(result)
--> 530 result = await result
531 except (CommClosedError, CancelledError):
532 if self.status in (Status.running, Status.paused):
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/worker.py in run()
4528 async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=True):
4529 kwargs = kwargs or {}
-> 4530 function = pickle.loads(function)
4531 if is_coro is None:
4532 is_coro = iscoroutinefunction(function)
/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/protocol/pickle.py in loads()
73 return pickle.loads(x, buffers=buffers)
74 else:
---> 75 return pickle.loads(x)
76 except Exception:
77 logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
ModuleNotFoundError: No module named 'dask_optuna'
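To sanity-check that interpretation, I think something like the sketch below would confirm whether dask_optuna is importable inside the scheduler and worker processes (the traceback comes from run_on_scheduler unpickling a dask_optuna function, so the scheduler environment matters too). The helper name has_dask_optuna is just made up for illustration:

def has_dask_optuna():
    # Runs remotely; import inside the function so only the stdlib is needed there
    import importlib.util
    return importlib.util.find_spec("dask_optuna") is not None

print(client.run_on_scheduler(has_dask_optuna))  # True/False for the scheduler
print(client.run(has_dask_optuna))               # {worker_address: True/False}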
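If that check comes back False, one thing I could presumably try is distributed's PipInstall worker plugin to get the package onto the workers. As far as I understand it only covers the worker processes, so the scheduler (where the ModuleNotFoundError above is raised) would still need dask-optuna in its environment, e.g. baked into the image used by dask-gateway. Sketch only:

from distributed.diagnostics.plugin import PipInstall

# Install dask-optuna into each worker's environment (workers only, not the scheduler)
client.register_worker_plugin(PipInstall(packages=["dask-optuna"], pip_options=["--upgrade"]))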