aimhubio / aim

Aim 💫 — An easy-to-use & supercharged open-source experiment tracker.
https://aimstack.io
Apache License 2.0

Data loaders abort when using multi-processing with a remote Aim repo #2540

Status: Open, opened by andychisholm 1 year ago

andychisholm commented 1 year ago

🐛 Bug

We're seeing a non-deterministic error during PyTorch Lightning training when we switch to a remote Aim repo for logging (i.e. setting repo="aim://our-aim-server:53800/" when initializing an aim.pytorch_lightning.AimLogger).

This only happens when switching the AimLogger from a local repo to a remote one, with no other changes to the codebase.
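One plausible mechanism, offered as an assumption rather than a confirmed diagnosis: with a remote repo, the training process holds a live network client whose background threads take internal locks, while DataLoader workers are started with `fork` (the default on Linux for the Python 3.8 shown in the traceback). A forked child inherits every lock in whatever state it was in at fork time, but not the threads that own those locks. The stdlib-only sketch below uses a plain `threading.Lock` as a hypothetical stand-in for the native mutex reported from `mutex.cc`; `client_lock`, `_hold_lock` and `lock_usable_in_forked_child` are illustrative names, not Aim or PyTorch APIs.

```python
import multiprocessing as mp
import threading

# Hypothetical stand-in for a mutex inside a network client's native code.
# The report does not say which library owns the mutex in mutex.cc.
client_lock = threading.Lock()


def _hold_lock(locked: threading.Event, release: threading.Event) -> None:
    # Background thread that is "mid-request" (holding the lock) when the
    # main thread forks, like a client thread sending metrics to a server.
    with client_lock:
        locked.set()
        release.wait()


def _child(conn) -> None:
    # Only the forking thread is copied into the child, so the thread that
    # owns the inherited lock does not exist here and can never release it.
    conn.send(client_lock.acquire(timeout=0.5))
    conn.close()


def lock_usable_in_forked_child() -> bool:
    locked, release = threading.Event(), threading.Event()
    holder = threading.Thread(target=_hold_lock, args=(locked, release), daemon=True)
    holder.start()
    locked.wait()  # make sure the lock is held at the moment of fork

    # Explicitly use "fork", mirroring DataLoader workers on Linux/Python 3.8.
    ctx = mp.get_context("fork")
    recv_conn, send_conn = ctx.Pipe(duplex=False)
    worker = ctx.Process(target=_child, args=(send_conn,))
    worker.start()
    usable = recv_conn.recv()
    worker.join()

    release.set()
    holder.join()
    return usable


if __name__ == "__main__":
    print("lock usable in forked worker:", lock_usable_in_forked_child())
```

If this is what happens here, the usual workarounds are to start workers from a fresh interpreter (PyTorch's `DataLoader` accepts `multiprocessing_context="spawn"` or `"forkserver"`), or for the client library to reset its state in the child via `os.register_at_fork(after_in_child=...)`. Neither has been tested against this issue.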

Mitigations

It's difficult to see how Aim could be involved in the data loader pipeline in a way that produces this behaviour, but this is what we observe.
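The two symptoms under Error Detail are probably one failure seen from two processes (again a hedged reading, not confirmed by the maintainers): a failed `RAW: Check` in native code typically ends in `abort()`, which raises SIGABRT in the worker, and the parent's DataLoader signal handler (`_error_if_any_worker_fails` in the traceback) then reports the worker as "killed by signal: Aborted". The sketch below reproduces only the parent-side view with the standard library; `_worker_that_aborts` and `worker_exit_signal` are hypothetical names.

```python
import multiprocessing as mp
import os
import signal


def _worker_that_aborts() -> None:
    # Stand-in for a DataLoader worker whose native code hits a failed
    # check and calls abort(); os.abort() raises SIGABRT in this process.
    os.abort()


def worker_exit_signal() -> int:
    # Use "fork" explicitly so the child does not re-import this module.
    ctx = mp.get_context("fork")
    proc = ctx.Process(target=_worker_that_aborts)
    proc.start()
    proc.join()
    # multiprocessing reports death-by-signal as a negative exit code.
    return -proc.exitcode


if __name__ == "__main__":
    sig = worker_exit_signal()
    print("worker killed by:", signal.Signals(sig).name)
```

Because a signal-killed child is reported through a negative `exitcode`, a result equal to `signal.SIGABRT` (6 on Linux) lines up with the "Aborted" wording in the traceback.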

Error Detail

During the first epoch we typically see one or more occurrences of:

[mutex.cc : 2374] RAW: Check w->waitp->cond == nullptr failed: Mutex::Fer while waiting on Condition

These are immediately followed by DataLoader abort stack traces, e.g.:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1120 in   │
│ _try_get_data                                                                                    │
│                                                                                                  │
│   1117 │   │   # Returns a 2-tuple:                                                              │
│   1118 │   │   #   (bool: whether successfully get data, any: data if successful else None)      │
│   1119 │   │   try:                                                                              │
│ ❱ 1120 │   │   │   data = self._data_queue.get(timeout=timeout)                                  │
│   1121 │   │   │   return (True, data)                                                           │
│   1122 │   │   except Exception as e:                                                            │
│   1123 │   │   │   # At timeout and error, we manually check whether any worker has              │
│                                                                                                  │
│ /usr/lib/python3.8/multiprocessing/queues.py:107 in get                                          │
│                                                                                                  │
│   104 │   │   │   try:                                                                           │
│   105 │   │   │   │   if block:                                                                  │
│   106 │   │   │   │   │   timeout = deadline - time.monotonic()                                  │
│ ❱ 107 │   │   │   │   │   if not self._poll(timeout):                                            │
│   108 │   │   │   │   │   │   raise Empty                                                        │
│   109 │   │   │   │   elif not self._poll():                                                     │
│   110 │   │   │   │   │   raise Empty                                                            │
│                                                                                                  │
│ /usr/lib/python3.8/multiprocessing/connection.py:257 in poll                                     │
│                                                                                                  │
│   254 │   │   """Whether there is any input available to be read"""                              │
│   255 │   │   self._check_closed()                                                               │
│   256 │   │   self._check_readable()                                                             │
│ ❱ 257 │   │   return self._poll(timeout)                                                         │
│   258 │                                                                                          │
│   259 │   def __enter__(self):                                                                   │
│   260 │   │   return self                                                                        │
│                                                                                                  │
│ /usr/lib/python3.8/multiprocessing/connection.py:424 in _poll                                    │
│                                                                                                  │
│   421 │   │   return self._recv(size)                                                            │
│   422 │                                                                                          │
│   423 │   def _poll(self, timeout):                                                              │
│ ❱ 424 │   │   r = wait([self], timeout)                                                          │
│   425 │   │   return bool(r)                                                                     │
│   426                                                                                            │
│   427                                                                                            │
│                                                                                                  │
│ /usr/lib/python3.8/multiprocessing/connection.py:931 in wait                                     │
│                                                                                                  │
│   928 │   │   │   │   deadline = time.monotonic() + timeout                                      │
│   929 │   │   │                                                                                  │
│   930 │   │   │   while True:                                                                    │
│ ❱ 931 │   │   │   │   ready = selector.select(timeout)                                           │
│   932 │   │   │   │   if ready:                                                                  │
│   933 │   │   │   │   │   return [key.fileobj for (key, events) in ready]                        │
│   934 │   │   │   │   else:                                                                      │
│                                                                                                  │
│ /usr/lib/python3.8/selectors.py:415 in select                                                    │
│                                                                                                  │
│   412 │   │   │   timeout = math.ceil(timeout * 1e3)                                             │
│   413 │   │   ready = []                                                                         │
│   414 │   │   try:                                                                               │
│ ❱ 415 │   │   │   fd_event_list = self._selector.poll(timeout)                                   │
│   416 │   │   except InterruptedError:                                                           │
│   417 │   │   │   return ready                                                                   │
│   418 │   │   for fd, event in fd_event_list:                                                    │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling. │
│ py:66 in handler                                                                                 │
│                                                                                                  │
│   63 │   def handler(signum, frame):                                                             │
│   64 │   │   # This following call uses `waitid` with WNOHANG from C side. Therefore,            │
│   65 │   │   # Python can still get and update the process status successfully.                  │
│ ❱ 66 │   │   _error_if_any_worker_fails()                                                        │
│   67 │   │   if previous_handler is not None:                                                    │
│   68 │   │   │   assert callable(previous_handler)                                               │
│   69 │   │   │   previous_handler(signum, frame)                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: DataLoader worker (pid 353773) is killed by signal: Aborted. 

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:36 in  │
│ _call_and_handle_interrupt                                                                       │
│                                                                                                  │
│   33 │   """                                                                                     │
│   34 │   try:                                                                                    │
│   35 │   │   if trainer.strategy.launcher is not None:                                           │
│ ❱ 36 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer,     │
│   37 │   │   else:                                                                               │
│   38 │   │   │   return trainer_fn(*args, **kwargs)                                              │
│   39                                                                                             │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/s │
│ ubprocess_script.py:90 in launch                                                                 │
│                                                                                                  │
│    87 │   │   """                                                                                │
│    88 │   │   if not self.cluster_environment.creates_processes_externally:                      │
│    89 │   │   │   self._call_children_scripts()                                                  │
│ ❱  90 │   │   return function(*args, **kwargs)                                                   │
│    91 │                                                                                          │
│    92 │   def _call_children_scripts(self) -> None:                                              │
│    93 │   │   # bookkeeping of spawned processes                                                 │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:645 │
│ in _fit_impl                                                                                     │
│                                                                                                  │
│    642 │   │   │   model_provided=True,                                                          │
│    643 │   │   │   model_connected=self.lightning_module is not None,                            │
│    644 │   │   )                                                                                 │
│ ❱  645 │   │   self._run(model, ckpt_path=self.ckpt_path)                                        │
│    646 │   │                                                                                     │
│    647 │   │   assert self.state.stopped                                                         │
│    648 │   │   self.training = False                                                             │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:109 │
│ 8 in _run                                                                                        │
│                                                                                                  │
│   1095 │   │                                                                                     │
│   1096 │   │   self._checkpoint_connector.resume_end()                                           │
│   1097 │   │                                                                                     │
│ ❱ 1098 │   │   results = self._run_stage()                                                       │
│   1099 │   │                                                                                     │
│   1100 │   │   log.detail(f"{self.__class__.__name__}: trainer tearing down")                    │
│   1101 │   │   self._teardown()                                                                  │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:117 │
│ 7 in _run_stage                                                                                  │
│                                                                                                  │
│   1174 │   │   │   return self._run_evaluate()                                                   │
│   1175 │   │   if self.predicting:                                                               │
│   1176 │   │   │   return self._run_predict()                                                    │
│ ❱ 1177 │   │   self._run_train()                                                                 │
│   1178 │                                                                                         │
│   1179 │   def _pre_training_routine(self) -> None:                                              │
│   1180 │   │   # wait for all to join if on distributed                                          │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:120 │
│ 0 in _run_train                                                                                  │
│                                                                                                  │
│   1197 │   │   self.fit_loop.trainer = self                                                      │
│   1198 │   │                                                                                     │
│   1199 │   │   with torch.autograd.set_detect_anomaly(self._detect_anomaly):                     │
│ ❱ 1200 │   │   │   self.fit_loop.run()                                                           │
│   1201 │                                                                                         │
│   1202 │   def _run_evaluate(self) -> _EVALUATE_OUTPUT:                                          │
│   1203 │   │   assert self.evaluating                                                            │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199 in   │
│ run                                                                                              │
│                                                                                                  │
│   196 │   │   while not self.done:                                                               │
│   197 │   │   │   try:                                                                           │
│   198 │   │   │   │   self.on_advance_start(*args, **kwargs)                                     │
│ ❱ 199 │   │   │   │   self.advance(*args, **kwargs)                                              │
│   200 │   │   │   │   self.on_advance_end()                                                      │
│   201 │   │   │   │   self._restarting = False                                                   │
│   202 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:267  │
│ in advance                                                                                       │
│                                                                                                  │
│   264 │   │   assert self._data_fetcher is not None                                              │
│   265 │   │   self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)              │
│   266 │   │   with self.trainer.profiler.profile("run_training_epoch"):                          │
│ ❱ 267 │   │   │   self._outputs = self.epoch_loop.run(self._data_fetcher)                        │
│   268 │                                                                                          │
│   269 │   def on_advance_end(self) -> None:                                                      │
│   270 │   │   # inform logger the batch loop has finished                                        │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199 in   │
│ run                                                                                              │
│                                                                                                  │
│   196 │   │   while not self.done:                                                               │
│   197 │   │   │   try:                                                                           │
│   198 │   │   │   │   self.on_advance_start(*args, **kwargs)                                     │
│ ❱ 199 │   │   │   │   self.advance(*args, **kwargs)                                              │
│   200 │   │   │   │   self.on_advance_end()                                                      │
│   201 │   │   │   │   self._restarting = False                                                   │
│   202 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_e │
│ poch_loop.py:188 in advance                                                                      │
│                                                                                                  │
│   185 │   │                                                                                      │
│   186 │   │   if not isinstance(data_fetcher, DataLoaderIterDataFetcher):                        │
│   187 │   │   │   batch_idx = self.batch_idx + 1                                                 │
│ ❱ 188 │   │   │   batch = next(data_fetcher)                                                     │
│   189 │   │   else:                                                                              │
│   190 │   │   │   batch_idx, batch = next(data_fetcher)                                          │
│   191 │   │   self.batch_progress.is_last_batch = data_fetcher.done                              │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py: │
│ 184 in __next__                                                                                  │
│                                                                                                  │
│   181 │   │   return self                                                                        │
│   182 │                                                                                          │
│   183 │   def __next__(self) -> Any:                                                             │
│ ❱ 184 │   │   return self.fetching_function()                                                    │
│   185 │                                                                                          │
│   186 │   def reset(self) -> None:                                                               │
│   187 │   │   self.fetched = 0                                                                   │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py: │
│ 265 in fetching_function                                                                         │
│                                                                                                  │
│   262 │   │   elif not self.done:                                                                │
│   263 │   │   │   # this will run only when no pre-fetching was done.                            │
│   264 │   │   │   try:                                                                           │
│ ❱ 265 │   │   │   │   self._fetch_next_batch(self.dataloader_iter)                               │
│   266 │   │   │   │   # consume the batch we just fetched                                        │
│   267 │   │   │   │   batch = self.batches.pop(0)                                                │
│   268 │   │   │   except StopIteration as e:                                                     │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py: │
│ 280 in _fetch_next_batch                                                                         │
│                                                                                                  │
│   277 │   def _fetch_next_batch(self, iterator: Iterator) -> None:                               │
│   278 │   │   start_output = self.on_fetch_start()                                               │
│   279 │   │   try:                                                                               │
│ ❱ 280 │   │   │   batch = next(iterator)                                                         │
│   281 │   │   except StopIteration as e:                                                         │
│   282 │   │   │   self._stop_profiler()                                                          │
│   283 │   │   │   raise e                                                                        │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py: │
│ 568 in __next__                                                                                  │
│                                                                                                  │
│   565 │   │   Returns:                                                                           │
│   566 │   │   │   a collections of batch data                                                    │
│   567 │   │   """                                                                                │
│ ❱ 568 │   │   return self.request_next_batch(self.loader_iters)                                  │
│   569 │                                                                                          │
│   570 │   @staticmethod                                                                          │
│   571 │   def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:       │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py: │
│ 580 in request_next_batch                                                                        │
│                                                                                                  │
│   577 │   │   Returns                                                                            │
│   578 │   │   │   Any: a collections of batch data                                               │
│   579 │   │   """                                                                                │
│ ❱ 580 │   │   return apply_to_collection(loader_iters, Iterator, next)                           │
│   581 │                                                                                          │
│   582 │   @staticmethod                                                                          │
│   583 │   def create_loader_iters(                                                               │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py:4 │
│ 7 in apply_to_collection                                                                         │
│                                                                                                  │
│    44 │   """                                                                                    │
│    45 │   # Breaking condition                                                                   │
│    46 │   if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dt   │
│ ❱  47 │   │   return function(data, *args, **kwargs)                                             │
│    48 │                                                                                          │
│    49 │   elem_type = type(data)                                                                 │
│    50                                                                                            │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628 in    │
│ __next__                                                                                         │
│                                                                                                  │
│    625 │   │   │   if self._sampler_iter is None:                                                │
│    626 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    627 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  628 │   │   │   data = self._next_data()                                                      │
│    629 │   │   │   self._num_yielded += 1                                                        │
│    630 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
│    631 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1316 in   │
│ _next_data                                                                                       │
│                                                                                                  │
│   1313 │   │   │   │   return self._process_data(data)                                           │
│   1314 │   │   │                                                                                 │
│   1315 │   │   │   assert not self._shutdown and self._tasks_outstanding > 0                     │
│ ❱ 1316 │   │   │   idx, data = self._get_data()                                                  │
│   1317 │   │   │   self._tasks_outstanding -= 1                                                  │
│   1318 │   │   │   if self._dataset_kind == _DatasetKind.Iterable:                               │
│   1319 │   │   │   │   # Check for _IterableDatasetStopIteration                                 │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1282 in   │
│ _get_data                                                                                        │
│                                                                                                  │
│   1279 │   │   │   # need to call `.task_done()` because we don't use `.join()`.                 │
│   1280 │   │   else:                                                                             │
│   1281 │   │   │   while True:                                                                   │
│ ❱ 1282 │   │   │   │   success, data = self._try_get_data()                                      │
│   1283 │   │   │   │   if success:                                                               │
│   1284 │   │   │   │   │   return data                                                           │
│   1285                                                                                           │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1133 in   │
│ _try_get_data                                                                                    │
│                                                                                                  │
│   1130 │   │   │   │   │   self._mark_worker_as_unavailable(worker_id)                           │
│   1131 │   │   │   if len(failed_workers) > 0:                                                   │
│   1132 │   │   │   │   pids_str = ', '.join(str(w.pid) for w in failed_workers)                  │
│ ❱ 1133 │   │   │   │   raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.f  │
│   1134 │   │   │   if isinstance(e, queue.Empty):                                                │
│   1135 │   │   │   │   return (False, None)                                                      │
│   1136 │   │   │   import tempfile                                                               │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: DataLoader worker (pid(s) 353773, 354505) exited unexpectedly

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /<path>/<project>/pipelines/modelling/cli.py:134 in train                              │
│                                                                                                  │
│   131 │   │   strategy=DDPStrategy(find_unused_parameters=False),                                │
│   133 │   )                                                                                      │
│ ❱ 134 │   trainer.fit(model=experiment, datamodule=data_module)                                  │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:603 │
│ in fit                                                                                           │
│                                                                                                  │
│    600 │   │   if not isinstance(model, pl.LightningModule):                                     │
│    601 │   │   │   raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.  │
│    602 │   │   self.strategy._lightning_module = model                                           │
│ ❱  603 │   │   call._call_and_handle_interrupt(                                                  │
│    604 │   │   │   self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule,  │
│    605 │   │   )                                                                                 │
│    606                                                                                           │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:59 in  │
│ _call_and_handle_interrupt                                                                       │
│                                                                                                  │
│   56 │   │   trainer.state.status = TrainerStatus.INTERRUPTED                                    │
│   57 │   │   if _distributed_available() and trainer.world_size > 1:                             │
│   58 │   │   │   # try syncing remaining processes, kill otherwise                               │
│ ❱ 59 │   │   │   trainer.strategy.reconciliate_processes(traceback.format_exc())                 │
│   60 │   │   trainer._call_callback_hooks("on_exception", exception)                             │
│   61 │   │   for logger in trainer.loggers:                                                      │
│   62 │   │   │   logger.finalize("failed")                                                       │
│                                                                                                  │
│ /<path>/venv/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py:461  │
│ in reconciliate_processes                                                                        │
│                                                                                                  │
│   458 │   │   │   if pid != os.getpid():                                                         │
│   459 │   │   │   │   os.kill(pid, signal.SIGKILL)                                               │
│   460 │   │   shutil.rmtree(sync_dir)                                                            │
│ ❱ 461 │   │   raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank   │
│   462 │                                                                                          │
│   463 │   def teardown(self) -> None:                                                            │
│   464 │   │   log.detail(f"{self.__class__.__name__}: tearing down strategy")                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
DeadlockDetectedException: DeadLock detected from rank: 0 
 Traceback (most recent call last):
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 353773) is killed by signal: Aborted. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 90, in launch
    return function(*args, **kwargs)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
    self.fit_loop.run()
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 188, in advance
    batch = next(data_fetcher)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
    self._fetch_next_batch(self.dataloader_iter)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
    batch = next(iterator)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py", line 568, in __next__
    return self.request_next_batch(self.loader_iters)
  File "/<path>/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py", line 580, in request_next_batch
    return apply_to_collection(loader_iters, Iterator, next)
  File "/<path>/venv/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 47, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1316, in _next_data
    idx, data = self._get_data()
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1282, in _get_data
    success, data = self._try_get_data()
  File "/<path>/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 353773, 354505) exited unexpectedly

To reproduce

Unable to provide a minimal reproduction at this stage.

Appreciate this is going to be incredibly difficult to debug! Just hoping someone's seen something like this before.

Expected behavior

Aim logger initialisation should not cause torch data loader deadlocks.

Environment

aim==3.15.2
aim-ui==3.15.2
aimrecords==0.0.7
aimrocks==0.2.1
grpcio==1.51.1

alberttorosyan commented 1 year ago

Hey @andychisholm! Thanks for submitting the issue with such details. Really appreciate that 🙌 I'll try to reproduce it, please do expect some questions during that process.

andychisholm commented 1 year ago

@alberttorosyan any thoughts on this one? Even a potentially fruitful direction to explore when debugging would be useful.

alberttorosyan commented 1 year ago

@andychisholm, I don't have good evidence on what's happening yet. The only possibility that comes to mind is the following: the AimLogger instance is available on the non-zero ranks (rank N), whereas the logger is intended to run only on rank 0.

I'll continue looking into this. Any additional information would be a huge help!
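One way to test the rank 0 hypothesis above is to construct the logger only on the global rank 0 process. The sketch below is stdlib-only and assumes the launcher exports RANK and/or LOCAL_RANK (torchrun sets both); current_rank and build_on_rank_zero are hypothetical helpers, not part of Aim or Lightning.

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

T = TypeVar("T")


def current_rank(env: Optional[Mapping[str, str]] = None) -> int:
    """Return this process's rank from the environment, defaulting to 0.

    Assumption: the launcher exports RANK (global) and/or LOCAL_RANK (per node).
    RANK is checked first so multi-node jobs gate on the global rank; if only
    LOCAL_RANK is present, local rank 0 on every node will pass the gate.
    """
    env = os.environ if env is None else env
    for key in ("RANK", "LOCAL_RANK"):
        value = env.get(key)
        if value is not None:
            return int(value)
    return 0  # no rank variables at all: treat as a single-process run


def build_on_rank_zero(
    factory: Callable[[], T], env: Optional[Mapping[str, str]] = None
) -> Optional[T]:
    """Call factory() only on rank 0; every other rank gets None and never builds it."""
    return factory() if current_rank(env) == 0 else None
```

In a training script this might look like logger = build_on_rank_zero(lambda: AimLogger(repo="aim://our-aim-server:53800/")), then passing logger=logger if logger is not None else False to the Trainer. Note that this only changes which ranks create a client; it does not change what data loader worker processes forked from rank 0 inherit.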

lminer commented 1 year ago

I'm seeing the same issue. My Aim repo is also remote. It seems related to this: https://github.com/Lightning-AI/lightning/issues/8821

andychisholm commented 1 year ago

Just to follow up on this one, I think it's due to the gRPC client not supporting fork. Regardless of whether the Aim loggers are used in the sub-processes, they break the data loaders in various non-deterministic ways.

For example, this occurs if you run DDP training with multiple GPUs and multiple data loader workers per GPU, but switching the multiprocessing start method from the default fork to spawn mitigates it.
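The fork theory above can be illustrated without Aim or gRPC. The stdlib-only sketch below (Linux/macOS only, since it calls os.fork) shows the general hazard: if a background thread holds a lock at the moment of fork, the child inherits a copy of the lock in its locked state but not the thread that would release it. gRPC runs background threads, so a forked DataLoader worker can inherit that kind of half-finished state; the mutex.cc abort in the first log is consistent with this, although nothing in this thread pins down the exact lock involved. The function name and parameters are hypothetical.

```python
import os
import threading


def child_can_acquire_after_fork(lock_held_by_thread: bool, timeout: float = 0.5) -> bool:
    """Fork, then report whether the child process can acquire `lock`.

    If lock_held_by_thread is True, a background thread holds the lock when
    fork() is called, mimicking a client library with internal worker threads.
    """
    lock = threading.Lock()
    holding = threading.Event()
    release = threading.Event()
    worker = None

    if lock_held_by_thread:
        def hold_lock() -> None:
            with lock:
                holding.set()   # signal that the lock is now taken
                release.wait()  # keep holding it until the parent says stop

        worker = threading.Thread(target=hold_lock, daemon=True)
        worker.start()
        holding.wait()  # make sure the lock is held before forking

    pid = os.fork()
    if pid == 0:
        # Child: only the thread that called fork() exists here, so nobody
        # will ever release a lock that was held at fork time.
        acquired = lock.acquire(timeout=timeout)
        os._exit(0 if acquired else 1)  # exit without running cleanup handlers

    # Parent: let the worker release its lock, then collect the child's result.
    release.set()
    if worker is not None:
        worker.join()
    _, status = os.waitpid(pid, 0)
    return os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0


if __name__ == "__main__":
    print(child_can_acquire_after_fork(lock_held_by_thread=False))  # True
    print(child_can_acquire_after_fork(lock_held_by_thread=True))   # False
```

Switching to spawn, as reported above, sidesteps this because each worker starts a fresh interpreter instead of copying the parent's memory. PyTorch's DataLoader accepts a multiprocessing_context argument (for example multiprocessing_context="spawn"), which applies that change per loader rather than globally. On Python 3.12+ the sketch may emit a DeprecationWarning about calling fork() while other threads are running, which is the same hazard being flagged.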

vikigenius commented 1 year ago

I can also confirm this issue. I am guessing it is related to https://github.com/aimhubio/aim/issues/1297.