BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

"OutOfMemoryError: CUDA out of memory." in CPU mode? #366

Open · Li-ZhiD opened 1 month ago

Li-ZhiD commented 1 month ago

I trained a model on scRNA-seq data in CPU mode (it took about 1 day). An error occurred when I tried the "Cell2location: spatial mapping" step, also in CPU mode.

mod.train(
    max_epochs=30000,
    # train using full data (batch_size=None)
    batch_size=None,
    # use all data points in training because
    # we need to estimate cell abundance at all locations
    train_size=1,
    use_gpu=False,
    num_particles=1
)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[23], line 1
----> 1 mod.train(
      2     max_epochs=30000,
      3     # train using full data (batch_size=None)
      4     batch_size=None,
      5     # use all data points in training because
      6     # we need to estimate cell abundance at all locations
      7     train_size=1,
      8     use_gpu=False,
      9     num_particles=1
     10 )
     12 # plot ELBO loss history during training, removing first 100 epochs from the plot
     13 mod.plot_history(1000)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:209, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
    206         scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"])
    207     kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 209 super().train(**kwargs)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:184, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    172 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    174 runner = self._train_runner_cls(
    175     self,
    176     training_plan=training_plan,
   (...)
    182     **trainer_kwargs,
    183 )
--> 184 return runner()

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41     if trainer.strategy.launcher is not None:
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43     return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
    567     ckpt_path,
    568     model_provided=True,
    569     model_connected=self.lightning_module is not None,
    570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
    573 assert self.state.stopped
    574 self.training = False

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:941, in Trainer._run(self, model, ckpt_path)
    938 self.strategy.setup_environment()
    939 self.__setup_profiler()
--> 941 call._call_setup_hook(self)  # allow user to setup lightning_module in accelerator environment
    943 # check if we should delay restoring checkpoint till later
    944 if not self.strategy.restore_checkpoint_after_setup:

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:85, in _call_setup_hook(trainer)
     82 trainer.strategy.barrier("pre_setup")
     84 if trainer.datamodule is not None:
---> 85     _call_lightning_datamodule_hook(trainer, "setup", stage=fn)
     86 _call_callback_hooks(trainer, "setup", stage=fn)
     87 _call_lightning_module_hook(trainer, "setup", stage=fn)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:166, in _call_lightning_datamodule_hook(trainer, hook_name, *args, **kwargs)
    164 if callable(fn):
    165     with trainer.profiler.profile(f"[LightningDataModule]{trainer.datamodule.__class__.__name__}.{hook_name}"):
--> 166         return fn(*args, **kwargs)
    167 return None

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/scvi/dataloaders/_data_splitting.py:431, in DeviceBackedDataSplitter.setup(self, stage)
    424     self.val_idx = (
    425         np.sort(self.val_idx) if len(self.val_idx) > 0 else self.val_idx
    426     )
    427     self.test_idx = (
    428         np.sort(self.test_idx) if len(self.test_idx) > 0 else self.test_idx
    429     )
--> 431 self.train_tensor_dict = self._get_tensor_dict(
    432     self.train_idx, device=self.device
    433 )
    434 self.test_tensor_dict = self._get_tensor_dict(self.test_idx, device=self.device)
    435 self.val_tensor_dict = self._get_tensor_dict(self.val_idx, device=self.device)

File /mnt/data1/ll/software/miniconda3/envs/cell2loc/lib/python3.9/site-packages/scvi/dataloaders/_data_splitting.py:453, in DeviceBackedDataSplitter._get_tensor_dict(self, indices, device)
    450         tensor_dict = batch
    452     for k, v in tensor_dict.items():
--> 453         tensor_dict[k] = v.to(device)
    455     return tensor_dict
    456 else:

OutOfMemoryError: CUDA out of memory. Tried to allocate 88.44 GiB. GPU 
vitkl commented 1 month ago

Despite this line reporting that the GPU was not used:

GPU available: True (cuda), used: False

the traceback shows that the data was still moved onto the GPU (hence the CUDA OutOfMemoryError). This likely points to an incompatibility between your cell2location and scvi-tools versions. I recommend installing the GitHub version of cell2location, which uses the 'accelerator' and 'device' arguments instead of 'use_gpu'.
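
For reference, a minimal sketch of what the call could look like with the GitHub version, assuming it exposes the newer scvi-tools-style 'accelerator'/'device' arguments (exact argument names and defaults may differ between releases):

# install the development version from GitHub (assumed install path)
# pip install git+https://github.com/BayraktarLab/cell2location.git

mod.train(
    max_epochs=30000,
    # train using full data (batch_size=None)
    batch_size=None,
    # use all data points in training because
    # we need to estimate cell abundance at all locations
    train_size=1,
    # replaces use_gpu=False in the newer scvi-tools API
    accelerator="cpu",
    num_particles=1
)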

That said, cell2location is going to take a very long time on CPU. If the data doesn't fit into GPU memory, I recommend reading https://github.com/BayraktarLab/cell2location/issues/356 and https://github.com/BayraktarLab/cell2location/issues/358 for tips on working with large datasets.
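
If the spatial dataset is too large to move to the GPU in one piece, one generic option exposed by Cell2location.train (the batch_size argument visible in the traceback signature above) is mini-batch training, so the full tensor dictionary is never placed on the device at once. A hedged sketch, not necessarily what the linked issues recommend, with a purely illustrative batch size:

mod.train(
    max_epochs=30000,
    # mini-batches instead of loading the full dataset onto the GPU at once
    batch_size=2500,
    # still use all data points across epochs
    train_size=1,
    accelerator="gpu",
    num_particles=1
)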