nttcslab-sp / mamba-diarization

Official repository for Mamba-based Segmentation Model for Speaker Diarization

pytorch_lightning.Trainer.fit() error #6

Closed. sipercai closed this issue 2 days ago.

sipercai commented 3 days ago

When I ran the 1_training_from_scratch script, trainer.validate(mamba_diar) worked without any problem. However, trainer.fit(mamba_diar) failed with the error below. It says the task's prepared_data has no "metadata" key when this line runs: *[self.prepared_data["metadata"][key] for key in balance]. Have you ever encountered a similar problem? Do you have any suggestions for solving it?


---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[17], line 1
----> 1 trainer.fit(mamba_diar)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1025, in Trainer._run_stage(self)
   1023     self._run_sanity_check()
   1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025     self.fit_loop.run()
   1026 return None
   1027 raise RuntimeError(f"Unexpected state {self.state}")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
    138 while not self.done:
    139     try:
--> 140         self.advance(data_fetcher)
    141         self.on_advance_end(data_fetcher)
    142         self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:212, in _TrainingEpochLoop.advance(self, data_fetcher)
    210 else:
    211     dataloader_iter = None
--> 212     batch, _, __ = next(data_fetcher)
    213     # TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the
    214     # fetcher state so that the batch_idx is correct after restarting
    215     batch_idx = self.batch_idx + 1

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:133, in _PrefetchDataFetcher.__next__(self)
    130     self.done = not self.batches
    131 elif not self.done:
    132     # this will run only when no pre-fetching was done.
--> 133     batch = super().__next__()
    134 else:
    135     # the iterator is empty
    136     raise StopIteration

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:60, in _DataFetcher.__next__(self)
     58 self._start_profiler()
     59 try:
---> 60     batch = next(self.iterator)
     61 except StopIteration:
     62     self.done = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:341, in CombinedLoader.__next__(self)
    339 def __next__(self) -> _ITERATOR_RETURN:
    340     assert self._iterator is not None
--> 341     out = next(self._iterator)
    342     if isinstance(self._iterator, _Sequential):
    343         return out

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:78, in _MaxSizeCycle.__next__(self)
     76 for i in range(n):
     77     try:
---> 78         out[i] = next(self.iterators[i])
     79     except StopIteration:
     80         self._consumed[i] = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1344, in _MultiProcessingDataLoaderIter._next_data(self)
   1342 else:
   1343     del self._task_info[idx]
-> 1344     return self._process_data(data)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1370, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1368 self._try_put_index()
   1369 if isinstance(data, ExceptionWrapper):
-> 1370     data.reraise()
   1371 return data

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
    702 except TypeError:
    703     # If the exception takes multiple arguments, don't try to
    704     # instantiate since we don't know how to
    705     raise RuntimeError(msg) from None
--> 706 raise exception

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in train__iter__
    *[self.prepared_data["metadata"][key] for key in balance]
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in <listcomp>
    *[self.prepared_data["metadata"][key] for key in balance]
      ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'metadata'

FrenchKrab commented 2 days ago

Thanks for the issue and the detailed logs! I pushed a PR to pyannote that should fix this issue, if you can install pyannote from git.

You can also sidestep this issue by not using balance (do not pass balance to the task, or pass None); that may be fine for your case. The balance parameter makes the task sample training data uniformly according to a criterion (in this paper, uniformly from each dataset).
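To illustrate what balance changes, here is a simplified sketch in plain Python. It is not pyannote's implementation: the file list, the database names, and the sample_files helper are hypothetical, and the sketch only groups files by the metadata combinations it actually observes. With balance=["database"], a database is drawn uniformly first and then a file within it, so a small database is sampled as often as a large one.

```python
import random
from collections import Counter
from typing import Dict, Iterator, List, Optional, Sequence, Tuple

# Hypothetical training metadata: 8 files from one database, 2 from another.
FILES: List[Dict[str, str]] = (
    [{"uri": f"big_{i}", "database": "BigDB"} for i in range(8)]
    + [{"uri": f"small_{i}", "database": "SmallDB"} for i in range(2)]
)


def sample_files(
    files: List[Dict[str, str]],
    balance: Optional[Sequence[str]] = None,
    seed: int = 0,
) -> Iterator[Dict[str, str]]:
    """Endlessly yield training files, optionally balanced over metadata keys."""
    rng = random.Random(seed)
    if not balance:
        # No balancing: every file is equally likely, so larger databases dominate.
        while True:
            yield rng.choice(files)

    # Group files by their values for the balance keys, e.g. ("SmallDB",).
    groups: Dict[Tuple[str, ...], List[Dict[str, str]]] = {}
    for f in files:
        groups.setdefault(tuple(f[key] for key in balance), []).append(f)
    group_keys = list(groups)

    while True:
        # Pick a group uniformly at random, then a file uniformly within that group.
        yield rng.choice(groups[rng.choice(group_keys)])


if __name__ == "__main__":
    for balance in (None, ["database"]):
        sampler = sample_files(FILES, balance=balance)
        counts = Counter(next(sampler)["database"] for _ in range(10_000))
        print(balance, dict(counts))  # roughly 80/20 without balance, 50/50 with it
```

Dropping balance, the workaround suggested above, corresponds to the first branch: each database then contributes in proportion to its size rather than equally.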

FrenchKrab commented 2 days ago

Actually, I realized there have been some breaking changes in recent pyannote versions, so I'm not 100% confident this PR will work right away, especially with this repository (and I don't have time to test things properly right now). But if you apply the commit from this PR to pyannote 3.1 (with git cherry-pick, for example), it should work. Or, again, if you are OK with not using balance, that's probably the easiest option.

Sorry for the inconvenience!

sipercai commented 2 days ago

Thank you for taking the time to reply amidst your busy schedule.

I tried both of your suggestions. First, removing balance=['database'] from the task worked: training then ran successfully! Your pull request (PR) also worked. I didn't know a quick way to apply your PR, so I copied its changes to task.py and mixins.py into the installed pyannote/audio package under site-packages in my conda environment (anaconda3/envs/mamba). After rerunning, training also worked normally!
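Copying files into site-packages by hand works, but it is easy to patch a different copy of pyannote than the one the notebook actually imports. A small sketch using only the standard library's importlib.util.find_spec prints where the active pyannote.audio package lives; installed_package_dir is a hypothetical helper name. The mixins.py location below is taken from the traceback above; where task.py goes should follow the path shown in the PR's diff. Running it in the same notebook kernel makes sure it reports the environment that training uses.

```python
import importlib.util
from pathlib import Path


def installed_package_dir(name: str) -> Path:
    """Return the directory of an importable package, e.g. "pyannote.audio"."""
    # For a dotted name, find_spec imports the parent package first and
    # raises ModuleNotFoundError if that parent is missing.
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{name!r} is not an importable package")
    # A regular package has a single search location: its own directory.
    return Path(next(iter(spec.submodule_search_locations)))


if __name__ == "__main__":
    try:
        audio_dir = installed_package_dir("pyannote.audio")
    except ModuleNotFoundError:
        print("pyannote.audio is not installed in this environment")
    else:
        # File that raised the KeyError, relative to the package directory.
        print(audio_dir / "tasks" / "segmentation" / "mixins.py")
```

Files patched this way are overwritten by the next pip install or upgrade of pyannote.audio, so the change has to be reapplied until the fix is part of a release.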

Thank you very much for your reply! Best regards!