BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

How do I set up multi-GPU parallel training? #388

Open noob000007 opened 1 week ago

noob000007 commented 1 week ago

How do I set up multi-GPU parallel training?

noob000007 commented 1 week ago

If I call super(RegressionModel, mod).train(max_epochs=100, batch_size=2500, train_size=1, lr=0.002, device=4), I get the following error:

---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
Cell In[13], line 9
      1 # Use all data for training (validation not implemented yet, train_size=1)
      2 # mod.train(
      3 #     max_epochs=100,
   (...)
      7 # )
      8 #
----> 9 super(RegressionModel, mod).train(max_epochs=100, batch_size=2500, train_size=1, lr=0.002,device=4)

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/scvi/model/base/_pyromixin.py:194, in PyroSviTrainMixin.train(self, max_epochs, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, datasplitter_kwargs, plan_kwargs, trainer_kwargs)
    183     trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    185 runner = self._train_runner_cls(
    186     self,
    187     training_plan=training_plan,
   (...)
    192     **trainer_kwargs,
    193 )
--> 194 return runner()

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/scvi/train/_trainrunner.py:96, in TrainRunner.__call__(self)
     93 if hasattr(self.data_splitter, "n_val"):
     94     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 96 self.trainer.fit(self.training_plan, self.data_splitter)
     97 self._update_history()
     99 # data splitter only gets these attrs after fit

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/scvi/train/_trainer.py:201, in Trainer.fit(self, *args, **kwargs)
    195 if isinstance(args[0], PyroTrainingPlan):
    196     warnings.filterwarnings(
    197         action="ignore",
    198         category=UserWarning,
    199         message="LightningModule.configure_optimizers returned None",
    200     )
--> 201 super().fit(*args, **kwargs)

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:46, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     44 try:
     45     if trainer.strategy.launcher is not None:
---> 46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py:144, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    136 process_context = mp.start_processes(
    137     self._wrapping_function,
    138     args=process_args,
   (...)
    141     join=False,  # we will join ourselves to get the process references
    142 )
    143 self.procs = process_context.processes
--> 144 while not process_context.join():
    145     pass
    147 worker_output = return_queue.get()

File ~/anaconda3/envs/cellana/envs/cell2loc_env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:170, in ProcessContext.join(self, timeout)
    168 except ValueError:
    169     name = f"<Unknown signal {-exitcode}>"
--> 170 raise ProcessExitedException(
    171     "process %d terminated with signal %s" % (error_index, name),
    172     error_index=error_index,
    173     error_pid=failed_process.pid,
    174     exit_code=exitcode,
    175     signal_name=name,
    176 )
    177 else:
    178     raise ProcessExitedException(
    179         "process %d terminated with exit code %d" % (error_index, exitcode),
    180         error_index=error_index,
    181         error_pid=failed_process.pid,
    182         exit_code=exitcode,
    183     )

ProcessExitedException: process 3 terminated with signal SIGSEGV

vitkl commented 1 week ago

Multi-GPU training is not needed for the regression model (its minibatches are small enough to fit on most single GPUs) and is quite non-trivial for the cell2location model. It is non-trivial because you need to keep the location-specific parameters, not just the location-specific data, on different GPU devices, and training has to use the full data rather than minibatches.
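To illustrate the point about location-specific parameters, here is a minimal pure-Python sketch. It is not cell2location code: the helper names `shard_bounds` and `shard_rows`, the toy sizes, and the `abundance` table standing in for per-location parameters are all hypothetical. It only shows that when locations are split across devices, the per-location parameter rows have to be split along the same boundaries as the data rows.

```python
def shard_bounds(n_locations, n_devices):
    """Split location indices 0..n_locations-1 into contiguous, near-equal shards."""
    base, extra = divmod(n_locations, n_devices)
    bounds = []
    start = 0
    for rank in range(n_devices):
        # The first `extra` devices take one additional location each.
        size = base + (1 if rank < extra else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


def shard_rows(rows, bounds):
    """Return one list of rows per device, following the given bounds."""
    return [rows[lo:hi] for lo, hi in bounds]


# Hypothetical toy sizes, chosen only for illustration.
n_locations, n_genes, n_cell_types, n_devices = 10, 3, 2, 4

# Per-location data (one row of counts per location).
counts = [[loc * n_genes + g for g in range(n_genes)] for loc in range(n_locations)]
# Per-location parameters (one row per location; assumed stand-in for the model's
# location-specific parameters).
abundance = [[0.0] * n_cell_types for _ in range(n_locations)]

bounds = shard_bounds(n_locations, n_devices)
counts_shards = shard_rows(counts, bounds)
abundance_shards = shard_rows(abundance, bounds)

for rank, (lo, hi) in enumerate(bounds):
    print(
        f"device {rank}: locations {lo}..{hi - 1}, "
        f"{len(counts_shards[rank])} data rows, "
        f"{len(abundance_shards[rank])} parameter rows"
    )
```

In ordinary data-parallel training every device holds a full copy of the parameters and only the data is split. Here each device would own a different slice of the parameters, which is part of why a generic multi-GPU launcher does not cover this case.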