BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

CUDA out of Memory while training the RegressionModel #168

Closed Rajarshi1001 closed 1 year ago

Rajarshi1001 commented 2 years ago

Please use the template below to post a question to https://discourse.scverse.org/c/ecosytem/cell2location/.

Problem

...

I have gone through the mouse brain cell2location tutorial and was planning to run the newer version of the Cell2location model on this dataset. The data preprocessing step completed successfully, but I am running into memory issues while training the RegressionModel.

Here is the block of code that produces the error:

mod.train(max_epochs=250, batch_size=4, train_size=0.2,validation_size=0.2, lr=0.002, use_gpu=True)

Here I have tried different batch_size values (even 1), but it still produces the error shown in the attached screenshot.

I have also stepped through the portion of the code that runs out of memory. The corresponding file is /usr/local/lib/python3.7/dist-packages/scvi/model/base/_pyromixin.py (running the code in Google Colab).

class PyroSviTrainMixin:
    """
    Mixin class for training Pyro models.

    Training using minibatches and using full data (copies data to GPU only once).
    """

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        early_stopping: bool = False,
        lr: Optional[float] = None,
        training_plan: PyroTrainingPlan = PyroTrainingPlan,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training. If `None`, no minibatching occurs and all
            data is copied to device (e.g., GPU).
        early_stopping
            Perform early stopping. Additional arguments can be passed in `**kwargs`.
            See :class:`~scvi.train.Trainer` for further options.
        lr
            Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
            Specifying optimiser via plan_kwargs overrides this choice of lr.
        training_plan
            Training plan :class:`~scvi.train.PyroTrainingPlan`.
        plan_kwargs
            Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        if max_epochs is None:
            n_obs = self.adata.n_obs
            max_epochs = np.min([round((20000 / n_obs) * 1000), 1000])

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        if lr is not None and "optim" not in plan_kwargs.keys():
            plan_kwargs.update({"optim_kwargs": {"lr": lr}})

        if batch_size is None:
            # use data splitter which moves data to GPU once
            data_splitter = DeviceBackedDataSplitter(
                self.adata_manager,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        else:
            data_splitter = DataSplitter(
                self.adata_manager,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        training_plan = training_plan(self.module, **plan_kwargs)

        es = "early_stopping"
        trainer_kwargs[es] = (
            early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
        )

        if "callbacks" not in trainer_kwargs.keys():
            trainer_kwargs["callbacks"] = []
        trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()

After stepping through this TrainRunner (I used pdb to step through each line of this class), the CUDA out of memory error is raised. Could you explain what causes this memory issue, since I am not able to train the RegressionModel on the mouse brain data, or suggest ways to work around it? Decreasing the batch_size parameter doesn't seem to help in this case.

ccruizm commented 2 years ago

Any updates on this issue, please? I am having the same problem.

adamgayoso commented 2 years ago

How many genes are you using? The default batch_size might be too large. If you want to play with smaller batch sizes you should restart your kernel to clear the GPU memory.
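
For example, something along these lines (a minimal sketch; adata_ref is a placeholder name for your reference AnnData and the parameter values are only illustrative, not recommendations):

import torch

# how many genes / cells does the reference AnnData actually contain?
print(adata_ref.n_vars, "genes x", adata_ref.n_obs, "cells")

# after restarting the kernel, optionally release any cached blocks as well
torch.cuda.empty_cache()

# retry with a smaller minibatch; all arguments shown appear in the train() signature above
mod.train(max_epochs=250, batch_size=1024, lr=0.002, use_gpu=True)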

Rajarshi1001 commented 2 years ago

I have tried changing the batch size several times and have also tried clearing the GPU memory, but the snippet still throws the CUDA out of memory error.

adamgayoso commented 2 years ago

Are you able to share a reproducible example on Colab?

ccruizm commented 2 years ago

I thought there was a problem with GPU memory allocation (I was using an RTX 6000 earlier), but I have also tried on an HPC node with an A100 that has 40 GB of memory and still get the same error. However, I can make it run if I change the batch size. A small batch size (5-10) leads to an estimated training time of ~900 h; when I increase it to 20000, it goes down to ~13 h (on the A100). I am not sure what the consequences of changing this parameter are.

The single-cell RNA model was trained using 24K genes (the intersection with the ST data is 19K). Would you suggest training it with fewer genes?

Thanks in advance.

ccruizm commented 2 years ago

I finally got the model trained after 14 h, but now that I want to load the model I am having the same problem:

mod = cell2location.models.Cell2location.load('results/cell2location_detailed_states_updated/cell2location_map', adata)

INFO     File                                                                                
         results/cell2location_detailed_states_updated/cell2location_map/model.pt    
         already downloaded                                                                  
INFO     Preparing underlying module for load                                                
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:120: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 1/30000:   0%|                                                 | 0/30000 [00:00<?, ?it/s]
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/distributions/gamma.py:71: UserWarning: Specified kernel cache directory could not be created! This disables kernel caching. Specified directory is /home/cruiz2/.cache/torch/kernels. This warning will appear only once per process. (Triggered internally at  ../aten/src/ATen/native/cuda/jit_utils.cpp:860.)
  self.rate * value - torch.lgamma(self.concentration))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:617, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
    616 try:
--> 617     model.module.load_state_dict(model_state_dict)
    618 except RuntimeError as err:

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
   1496 if len(error_msgs) > 0:
-> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1498                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Cell2locationBaseModule:
    Unexpected key(s) in state_dict: "_guide.locs.m_g_mean_unconstrained", "_guide.locs.m_g_alpha_e_inv_unconstrained", "_guide.locs.m_g_unconstrained", "_guide.locs.n_s_cells_per_location_unconstrained", "_guide.locs.b_s_groups_per_location_unconstrained", "_guide.locs.z_sr_groups_factors_unconstrained", "_guide.locs.k_r_factors_per_groups_unconstrained", "_guide.locs.x_fr_group2fact_unconstrained", "_guide.locs.w_sf_unconstrained", "_guide.locs.detection_mean_y_e_unconstrained", "_guide.locs.detection_y_s_unconstrained", "_guide.locs.s_g_gene_add_alpha_hyp_unconstrained", "_guide.locs.s_g_gene_add_mean_unconstrained", "_guide.locs.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.locs.s_g_gene_add_unconstrained", "_guide.locs.alpha_g_phi_hyp_unconstrained", "_guide.locs.alpha_g_inverse_unconstrained", "_guide.scales.m_g_mean_unconstrained", "_guide.scales.m_g_alpha_e_inv_unconstrained", "_guide.scales.m_g_unconstrained", "_guide.scales.n_s_cells_per_location_unconstrained", "_guide.scales.b_s_groups_per_location_unconstrained", "_guide.scales.z_sr_groups_factors_unconstrained", "_guide.scales.k_r_factors_per_groups_unconstrained", "_guide.scales.x_fr_group2fact_unconstrained", "_guide.scales.w_sf_unconstrained", "_guide.scales.detection_mean_y_e_unconstrained", "_guide.scales.detection_y_s_unconstrained", "_guide.scales.s_g_gene_add_alpha_hyp_unconstrained", "_guide.scales.s_g_gene_add_mean_unconstrained", "_guide.scales.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.scales.s_g_gene_add_unconstrained", "_guide.scales.alpha_g_phi_hyp_unconstrained", "_guide.scales.alpha_g_inverse_unconstrained". 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 mod = cell2location.models.Cell2location.load('results/gbm_all_cell2location_detailed_states_updated/cell2location_map', 
      2                                               adata)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:622, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
    620 old_history = model.history_.copy()
    621 logger.info("Preparing underlying module for load")
--> 622 model.train(max_steps=1)
    623 model.history_ = old_history
    624 pyro.clear_param_store()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:209, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
    206         scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_genes"])
    207     kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 209 super().train(**kwargs)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:146, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    136 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    138 runner = TrainRunner(
    139     self,
    140     training_plan=training_plan,
   (...)
    144     **trainer_kwargs,
    145 )
--> 146 return runner()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainrunner.py:74, in TrainRunner.__call__(self)
     71 if hasattr(self.data_splitter, "n_val"):
     72     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
     75 self._update_history()
     77 # data splitter only gets these attrs after fit

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735     rank_zero_deprecation(
    736         "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
    737         " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
    738     )
    739     train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to `trainer_fn`
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1199, in Trainer._run(self, model, ckpt_path)
   1196 self.checkpoint_connector.resume_end()
   1198 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199 self._dispatch()
   1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
   1202 self._post_dispatch()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1279, in Trainer._dispatch(self)
   1277     self.training_type_plugin.start_predicting(self)
   1278 else:
-> 1279     self.training_type_plugin.start_training(self)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:202, in TrainingTypePlugin.start_training(self, trainer)
    200 def start_training(self, trainer: "pl.Trainer") -> None:
    201     # double dispatch to initiate the training loop
--> 202     self._results = trainer.run_stage()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1289, in Trainer.run_stage(self)
   1287 if self.predicting:
   1288     return self._run_predict()
-> 1289 return self._run_train()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1319, in Trainer._run_train(self)
   1317 self.fit_loop.trainer = self
   1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319     self.fit_loop.run()

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:234, in FitLoop.advance(self)
    231 data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
    233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234     self.epoch_loop.run(data_fetcher)
    236     # the global step is manually decreased here due to backwards compatibility with existing loggers
    237     # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
    238     # finished. this means the attribute does not exactly track the number of optimizer steps applied.
    239     # TODO(@carmocca): deprecate and rename so users don't get confused
    240     self.global_step -= 1

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:193, in TrainingEpochLoop.advance(self, *args, **kwargs)
    190     self.batch_progress.increment_started()
    192     with self.trainer.profiler.profile("run_training_batch"):
--> 193         batch_output = self.batch_loop.run(batch, batch_idx)
    195 self.batch_progress.increment_processed()
    197 # update non-plateau LR schedulers
    198 # update epoch-interval ones only when we are at the end of training epoch

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:90, in TrainingBatchLoop.advance(self, batch, batch_idx)
     88     outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
     89 else:
---> 90     outputs = self.manual_loop.run(split_batch, batch_idx)
     91 if outputs:
     92     # automatic: can be empty if all optimizers skip their batches
     93     # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
     94     # then `advance` doesn't finish and an empty dict is returned
     95     self._outputs.append(outputs)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
    143 try:
    144     self.on_advance_start(*args, **kwargs)
--> 145     self.advance(*args, **kwargs)
    146     self.on_advance_end()
    147     self.restarting = False

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/manual_loop.py:111, in ManualOptimization.advance(self, batch, batch_idx)
    109 lightning_module._current_fx_name = "training_step"
    110 with self.trainer.profiler.profile("training_step"):
--> 111     training_step_output = self.trainer.accelerator.training_step(step_kwargs)
    112     self.trainer.training_type_plugin.post_training_step()
    114 del step_kwargs

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py:219, in Accelerator.training_step(self, step_kwargs)
    214 """The actual training step.
    215 
    216 See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
    217 """
    218 with self.precision_plugin.train_step_context():
--> 219     return self.training_type_plugin.training_step(*step_kwargs.values())

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:213, in TrainingTypePlugin.training_step(self, *args, **kwargs)
    212 def training_step(self, *args, **kwargs):
--> 213     return self.model.training_step(*args, **kwargs)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainingplans.py:743, in PyroTrainingPlan.training_step(self, batch, batch_idx)
    741     kwargs.update({"kl_weight": self.kl_weight})
    742 # pytorch lightning requires a Tensor object for loss
--> 743 loss = torch.Tensor([self.svi.step(*args, **kwargs)])
    745 return {"loss": loss}

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:157, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    153     if trainable_params and getattr(
    154         surrogate_loss_particle, "requires_grad", False
    155     ):
    156         surrogate_loss_particle = surrogate_loss_particle / self.num_particles
--> 157         surrogate_loss_particle.backward(retain_graph=self.retain_graph)
    158 warn_if_nan(loss, "loss")
    159 return loss

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/_tensor.py:363, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    354 if has_torch_function_unary(self):
    355     return handle_torch_function(
    356         Tensor.backward,
    357         (self,),
   (...)
    361         create_graph=create_graph,
    362         inputs=inputs)
--> 363 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: CUDA out of memory. Tried to allocate 2.28 GiB (GPU 0; 39.44 GiB total capacity; 36.77 GiB already allocated; 91.62 MiB free; 37.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Any help, please!

AshleyLu commented 1 year ago

I also have the same problem, even after running torch.cuda.empty_cache():

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.81 GiB (GPU 1; 79.21 GiB total capacity; 72.93 GiB already allocated; 4.35 GiB free; 73.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The CUDA memory summary (screenshot attached) looks OK to me.

vitkl commented 1 year ago

If your problem is with loading or training the regression model, please check that you are using the latest cell2location.
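
If the out-of-memory error appears during loading, note from the traceback above that load() re-runs a single training step on the full data. One workaround worth trying (a sketch, relying only on the use_gpu argument visible in the BaseModelClass.load signature in that traceback) is to load the model on CPU first:

import cell2location

# load on CPU so the one-step warm-up train() inside load() does not copy the
# full dataset to the GPU; the model can be moved to GPU afterwards if needed
mod = cell2location.models.Cell2location.load(
    'results/cell2location_detailed_states_updated/cell2location_map',
    adata,
    use_gpu=False,
)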

If your problem is with the main cell2location model, your dataset could be too large - it is realistic to use about 16k genes × 50k locations on an A100 80GB GPU.
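
If your dataset is above that size, reducing the number of genes before setup is the usual way to bring memory down. A rough sketch following the gene-selection step used in the cell2location tutorial (adata_ref is a placeholder for the reference AnnData and the cutoff values are only illustrative):

from cell2location.utils.filtering import filter_genes

# keep genes detected in enough cells and with a sufficiently high non-zero mean
selected = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12,
)
adata_ref = adata_ref[:, selected].copy()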

AshleyLu commented 1 year ago

@vitkl Thank you for the explanation. I ran into the CUDA memory problem when running the main cell2location model. My dataset was indeed bigger (17k genes over 70k locations). After reducing the genes to 13k, CUDA is happy again.