RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License

TypeError: cannot pickle 'module' object #21

Closed: JonathanFL closed this issue 1 year ago

JonathanFL commented 1 year ago

I am running htsat_esc_training.ipynb and getting this error on my PC.

Python version: 3.9.12. Installed all requirements from requirements.txt. Ran the notebook in VSCode. No changes to the code.

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

trainer properties:
  gpus: 1
  max_epochs: 100
  auto_lr_find: True
  accelerator: <pytorch_lightning.accelerators.gpu.GPUAccelerator object at 0x0000016B08410790>
  num_sanity_val_steps: 0
  resume_from_checkpoint: None
  gradient_clip_val: 1.0
```
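For reference, these properties correspond roughly to a Trainer built like the sketch below (PyTorch Lightning 1.5-era arguments; this is an illustrative reconstruction, not the notebook's exact code):

```python
# Hypothetical sketch of the reported Trainer settings
# (PyTorch Lightning 1.5-style keyword arguments).
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,                      # GPU available: True, used: True
    max_epochs=100,
    auto_lr_find=True,
    num_sanity_val_steps=0,
    resume_from_checkpoint=None,
    gradient_clip_val=1.0,
)
```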

Error:

```
TypeError                                 Traceback (most recent call last)
Cell In [26], line 3
      1 # Training the model
      2 # You can set different fold index by setting 'esc_fold' to any number from 0-4 in esc_config.py
----> 3 trainer.fit(model, audioset_data)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735 rank_zero_deprecation(
    736     "trainer.fit(train_dataloader) is deprecated in v1.4 and will be removed in v1.6."
    737     " Use trainer.fit(train_dataloaders) instead. HINT: added 's'"
    738 )
    739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
    (...)
    682     **kwargs: keyword arguments to be passed to trainer_fn
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1199, in Trainer._run(self, model, ckpt_path)
   1196     self.checkpoint_connector.resume_end()
   1198 # dispatch start_training or start_evaluating or start_predicting
-> 1199 self._dispatch()
   1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
   1202 self._post_dispatch()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1279, in Trainer._dispatch(self)
   1277     self.training_type_plugin.start_predicting(self)
   1278 else:
-> 1279     self.training_type_plugin.start_training(self)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py:202, in TrainingTypePlugin.start_training(self, trainer)
    200 def start_training(self, trainer: "pl.Trainer") -> None:
    201     # double dispatch to initiate the training loop
--> 202     self._results = trainer.run_stage()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1289, in Trainer.run_stage(self)
   1287 if self.predicting:
   1288     return self._run_predict()
-> 1289 return self._run_train()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1319, in Trainer._run_train(self)
   1317 self.fit_loop.trainer = self
   1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319     self.fit_loop.run()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\fit_loop.py:234, in FitLoop.advance(self)
    231 data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
    233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234     self.epoch_loop.run(data_fetcher)
    236 # the global step is manually decreased here due to backwards compatibility with existing loggers
    237 # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
    238 # finished. this means the attribute does not exactly track the number of optimizer steps applied.
    239 # TODO(@carmocca): deprecate and rename so users don't get confused
    240 self.global_step -= 1

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\base.py:140, in Loop.run(self, *args, **kwargs)
    136     return self.on_skip()
    138 self.reset()
--> 140 self.on_run_start(*args, **kwargs)
    142 while not self.done:
    143     try:

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py:141, in TrainingEpochLoop.on_run_start(self, data_fetcher, **kwargs)
    138 self.trainer.fit_loop.epoch_progress.increment_started()
    140 self._reload_dataloader_state_dict(data_fetcher)
--> 141 self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\utilities.py:121, in _update_dataloader_iter(data_fetcher, batch_idx)
    118 """Attach the dataloader."""
    119 if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    120     # restore iteration
--> 121     dataloader_iter = enumerate(data_fetcher, batch_idx)
    122 else:
    123     dataloader_iter = iter(data_fetcher)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\fetching.py:198, in AbstractDataFetcher.__iter__(self)
    196 self.reset()
    197 self.dataloader_iter = iter(self.dataloader)
--> 198 self._apply_patch()
    199 self.prefetching(self.prefetch_batches)
    200 return self

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\fetching.py:133, in AbstractDataFetcher._apply_patch(self)
    130         loader._lightning_fetcher = self
    131         patch_dataloader_iterator(loader, iterator, self)
--> 133     apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\fetching.py:181, in AbstractDataFetcher.loader_iters(self)
    178     raise MisconfigurationException("The dataloader_iter isn't available outside the iter context.")
    180 if isinstance(self.dataloader, CombinedLoader):
--> 181     loader_iters = self.dataloader_iter.loader_iters
    182 else:
    183     loader_iters = [self.dataloader_iter]

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\supporters.py:537, in CombinedLoaderIterator.loader_iters(self)
    535 """Get the _loader_iters and create one if it is None."""
    536 if self._loader_iters is None:
--> 537     self._loader_iters = self.create_loader_iters(self.loaders)
    539 return self._loader_iters

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\supporters.py:577, in CombinedLoaderIterator.create_loader_iters(loaders)
    568 """Create and return a collection of iterators from loaders.
    569
    570 Args:
    (...)
    574     a collections of iterators
    575 """
    576 # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
--> 577 return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\apply_func.py:95, in apply_to_collection(data, dtype, function, wrong_dtype, include_none, *args, **kwargs)
     93 # Breaking condition
     94 if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
---> 95     return function(data, *args, **kwargs)
     97 elem_type = type(data)
     99 # Recursively apply to collection items

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\dataloader.py:444, in DataLoader.__iter__(self)
    442     return self._iterator
    443 else:
--> 444     return self._get_iterator()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\dataloader.py:390, in DataLoader._get_iterator(self)
    388 else:
    389     self.check_worker_number_rationality()
--> 390     return _MultiProcessingDataLoaderIter(self)

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\dataloader.py:1077, in _MultiProcessingDataLoaderIter.__init__(self, loader)
   1070 w.daemon = True
   1071 # NB: Process.start() actually take some time as it needs to
   1072 #     start a process and pass the arguments over via a pipe.
   1073 #     Therefore, we only add a worker to self._workers list after
   1074 #     it started, so that we do not call .join() if program dies
   1075 #     before it starts, and __del__ tries to join but will get:
   1076 #     AssertionError: can only join a started process.
-> 1077 w.start()
   1078 self._index_queues.append(index_queue)
   1079 self._workers.append(w)

File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\context.py:224, in Process._Popen(process_obj)
    222 @staticmethod
    223 def _Popen(process_obj):
--> 224     return _default_context.get_context().Process._Popen(process_obj)

File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\context.py:327, in SpawnProcess._Popen(process_obj)
    324 @staticmethod
    325 def _Popen(process_obj):
    326     from .popen_spawn_win32 import Popen
--> 327     return Popen(process_obj)

File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\popen_spawn_win32.py:93, in Popen.__init__(self, process_obj)
     91 try:
     92     reduction.dump(prep_data, to_child)
---> 93     reduction.dump(process_obj, to_child)
     94 finally:
     95     set_spawning_popen(None)

File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

TypeError: cannot pickle 'module' object
```
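For background on this error class: module objects cannot be pickled in Python, and on Windows DataLoader worker processes are started with spawn, which pickles the dataset (and anything it references) to hand it to the worker. A minimal illustration, independent of the notebook:

```python
# Minimal illustration (hypothetical, not the notebook's code): a module object
# is not picklable, so a Dataset that keeps a module as an attribute fails in
# exactly this way when spawn-based DataLoader workers try to serialize it.
import pickle
import sys

try:
    pickle.dumps(sys)  # sys is a module object
except TypeError as err:
    print(err)  # -> cannot pickle 'module' object
```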


Which version of PyTorch would you recommend?

```
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
```

RetroCirce commented 1 year ago

Hi,

I use PyTorch 1.7.1 with CUDA 10, but I think PyTorch 1.11.0 with CUDA 11 should also work. You can also try Python 3.7 or 3.8 to see if that solves the problem.
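Whichever build ends up installed, a quick sanity check of the environment (a generic snippet, not specific to this repo) looks like this:

```python
# Generic environment check: report the installed PyTorch version, the CUDA
# version it was built against, and whether a GPU is actually visible.
import torch

print(torch.__version__)          # e.g. 1.7.1 or 1.11.0
print(torch.version.cuda)         # e.g. 10.1 or 11.3
print(torch.cuda.is_available())  # True if the GPU can be used
```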

Based on the log you posted, the problem is most likely in the multiprocessing data loading. You can try setting num_workers = 0 to disable multiprocessing in the dataloader and see if that solves the problem. A minimal sketch of that workaround is shown below.
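The sketch below uses a stand-in dataset for illustration; in this repo the actual dataloaders are built inside the training code, and the value is likely exposed through a num_workers setting in esc_config.py:

```python
# Hypothetical, self-contained sketch of the workaround: with num_workers=0 the
# DataLoader loads batches in the main process, so no worker process is spawned
# and nothing has to be pickled on Windows.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 1, 32000), torch.zeros(8, dtype=torch.long))  # stand-in data
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for waveform, label in loader:
    print(waveform.shape, label.shape)
```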

JonathanFL commented 1 year ago

Setting num_workers = 0 worked. Thanks!