RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License

cannot pickle 'module' object when running the htsat_esc_training #27

Closed · kremHabashy closed 1 year ago

kremHabashy commented 1 year ago

Once trainer.fit(model, audioset_data) is called, the error below is output. Any help on the matter would be greatly appreciated!


```
TypeError                                 Traceback (most recent call last)
Cell In [10], line 3
      1 # Training the model
      2 # You can set different fold index by setting 'esc_fold' to any number from 0-4 in esc_config.py
----> 3 trainer.fit(model, audioset_data)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735 rank_zero_deprecation(
    736     "trainer.fit(train_dataloader) is deprecated in v1.4 and will be removed in v1.6."
    737     " Use trainer.fit(train_dataloaders) instead. HINT: added 's'"
    738 )
    739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to trainer_fn
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1199, in Trainer._run(self, model, ckpt_path)
   1196 self.checkpoint_connector.resume_end()
   1198 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199 self._dispatch()
   1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
   1202 self._post_dispatch()

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1279, in Trainer._dispatch(self)
   1277     self.training_type_plugin.start_predicting(self)
   1278 else:
-> 1279     self.training_type_plugin.start_training(self)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:202, in TrainingTypePlugin.start_training(self, trainer)
    200 def start_training(self, trainer: "pl.Trainer") -> None:
    201     # double dispatch to initiate the training loop
--> 202     self._results = trainer.run_stage()

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1289, in Trainer.run_stage(self)
   1287 if self.predicting:
   1288     return self._run_predict()
-> 1289 return self._run_train()

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1319, in Trainer._run_train(self)
   1317 self.fit_loop.trainer = self
   1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319     self.fit_loop.run()

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:234, in FitLoop.advance(self)
    231 data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
    233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234     self.epoch_loop.run(data_fetcher)
    236 # the global step is manually decreased here due to backwards compatibility with existing loggers
    237 # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
    238 # finished. this means the attribute does not exactly track the number of optimizer steps applied.
    239 # TODO(@carmocca): deprecate and rename so users don't get confused
    240 self.global_step -= 1

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:140, in Loop.run(self, *args, **kwargs)
    136     return self.on_skip()
    138 self.reset()
--> 140 self.on_run_start(*args, **kwargs)
    142 while not self.done:
    143     try:

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:141, in TrainingEpochLoop.on_run_start(self, data_fetcher, **kwargs)
    138 self.trainer.fit_loop.epoch_progress.increment_started()
    140 self._reload_dataloader_state_dict(data_fetcher)
--> 141 self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py:121, in _update_dataloader_iter(data_fetcher, batch_idx)
    118 """Attach the dataloader."""
    119 if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    120     # restore iteration
--> 121     dataloader_iter = enumerate(data_fetcher, batch_idx)
    122 else:
    123     dataloader_iter = iter(data_fetcher)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py:198, in AbstractDataFetcher.__iter__(self)
    196 self.reset()
    197 self.dataloader_iter = iter(self.dataloader)
--> 198 self._apply_patch()
    199 self.prefetching(self.prefetch_batches)
    200 return self

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py:133, in AbstractDataFetcher._apply_patch(self)
    130     loader._lightning_fetcher = self
    131     patch_dataloader_iterator(loader, iterator, self)
--> 133 apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py:181, in AbstractDataFetcher.loader_iters(self)
    178     raise MisconfigurationException("The `dataloader_iter` isn't available outside the iter context.")
    180 if isinstance(self.dataloader, CombinedLoader):
--> 181     loader_iters = self.dataloader_iter.loader_iters
    182 else:
    183     loader_iters = [self.dataloader_iter]

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py:537, in CombinedLoaderIterator.loader_iters(self)
    535 """Get the `_loader_iters` and create one if it is None."""
    536 if self._loader_iters is None:
--> 537     self._loader_iters = self.create_loader_iters(self.loaders)
    539 return self._loader_iters

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py:577, in CombinedLoaderIterator.create_loader_iters(loaders)
    568 """Create and return a collection of iterators from loaders.
    569
    570 Args:
   (...)
    574     a collections of iterators
    575 """
    576 # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
--> 577 return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/pytorch_lightning/utilities/apply_func.py:95, in apply_to_collection(data, dtype, function, *args, wrong_dtype, include_none, **kwargs)
     93 # Breaking condition
     94 if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
---> 95     return function(data, *args, **kwargs)
     97 elem_type = type(data)
     99 # Recursively apply to collection items

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/torch/utils/data/dataloader.py:435, in DataLoader.__iter__(self)
    433     return self._iterator
    434 else:
--> 435     return self._get_iterator()

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/torch/utils/data/dataloader.py:381, in DataLoader._get_iterator(self)
    379 else:
    380     self.check_worker_number_rationality()
--> 381     return _MultiProcessingDataLoaderIter(self)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1034, in _MultiProcessingDataLoaderIter.__init__(self, loader)
   1027 w.daemon = True
   1028 # NB: Process.start() actually take some time as it needs to
   1029 #     start a process and pass the arguments over via a pipe.
   1030 #     Therefore, we only add a worker to self._workers list after
   1031 #     it started, so that we do not call .join() if program dies
   1032 #     before it starts, and __del__ tries to join but will get:
   1033 #     AssertionError: can only join a started process.
-> 1034 w.start()
   1035 self._index_queues.append(index_queue)
   1036 self._workers.append(w)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/context.py:224, in Process._Popen(process_obj)
    222 @staticmethod
    223 def _Popen(process_obj):
--> 224     return _default_context.get_context().Process._Popen(process_obj)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/context.py:284, in SpawnProcess._Popen(process_obj)
    281 @staticmethod
    282 def _Popen(process_obj):
    283     from .popen_spawn_posix import Popen
--> 284     return Popen(process_obj)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
     30 def __init__(self, process_obj):
     31     self._fds = []
---> 32     super().__init__(process_obj)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     17 self.returncode = None
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
     45 try:
     46     reduction.dump(prep_data, fp)
---> 47     reduction.dump(process_obj, fp)
     48 finally:
     49     set_spawning_popen(None)

File ~/opt/anaconda3/envs/htsaudio/lib/python3.8/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

TypeError: cannot pickle 'module' object
```
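
Reading the last few frames: the failure happens while multiprocessing spawns a DataLoader worker, when reduction.dump() pickles the worker process object; something reachable from it (typically the dataset or one of its attributes) holds a plain Python module, which pickle cannot serialize. Here is a minimal sketch of the same failure mode (hypothetical, not the repo's actual code: the DatasetHoldingModule class and the use of json as a stand-in for a config module like esc_config are made up for illustration):

```python
import pickle
import json  # stand-in for a plain config module such as esc_config

class DatasetHoldingModule:
    """Hypothetical dataset-like object that stores a module as its config."""
    def __init__(self, config_module):
        self.config = config_module  # module objects are not picklable

# Spawn-based DataLoader workers must pickle the dataset to ship it to the
# worker process; doing so directly reproduces the same TypeError:
pickle.dumps(DatasetHoldingModule(json))
# TypeError: cannot pickle 'module' object
```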

RetroCirce commented 1 year ago

Hi, please refer to this: https://github.com/RetroCirce/HTS-Audio-Transformer/issues/21

I think it is because pytorch-lightning has been updated to 1.8.0, and there are some compatibility issues with the previous versions this repo was written against.
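
If you cannot change the pytorch-lightning version right away, a common workaround for this class of error is to set num_workers=0 wherever the DataLoader is constructed (if audioset_data is a LightningDataModule, that would be inside its train_dataloader / val_dataloader methods): with no worker processes spawned, nothing has to be pickled. A self-contained sketch, using a placeholder dataset instead of the repo's ESC-50 dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset, only to keep the sketch runnable on its own.
dataset = TensorDataset(torch.randn(8, 1, 64, 64), torch.zeros(8, dtype=torch.long))

# num_workers=0 loads batches in the main process, so the dataset is never
# serialized for a spawned worker and the pickling error cannot occur
# (at the cost of slower data loading).
loader = DataLoader(dataset, batch_size=4, num_workers=0)
for batch in loader:
    pass
```

Alternatively, pinning pytorch-lightning to the version the repo was developed against (check the repository's requirements file for the exact pin) avoids the breaking API changes mentioned above.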