Any updates on this issue, please? I'm having the same problem.
How many genes are you using? The default batch_size might be too large. If you want to play with smaller batch sizes you should restart your kernel to clear the GPU memory.
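For reference, a minimal sketch of what that looks like, assuming the standard scvi-tools train() arguments (max_epochs, batch_size, use_gpu) that also appear in the traceback further down; the values are illustrative only, not recommendations:

import gc
import torch

# free cached (but no longer used) GPU memory before retrying
gc.collect()
torch.cuda.empty_cache()

# 'mod' is assumed to be an already set-up RegressionModel / Cell2location instance
mod.train(
    max_epochs=250,
    batch_size=1024,  # smaller than the default; reduce further if OOM persists
    use_gpu=True,
)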
I have tried changing the batch size several times and have also tried clearing the GPU memory, but the snippet still throws the CUDA out of memory error.
Are you able to share a reproducible example on Colab?
I thought there was a problem with GPU memory allocation (I used an RTX 6000 earlier), but I have also tried on an HPC with an A100 that has 40 GB of memory and still get the same error. However, I can make it run if I change the batch size: a small batch size (5-10) leads to an estimated training time of 900 h, while increasing it to 20000 brings it down to 13 h (on the A100). I am not sure what the consequences of changing this parameter are.
The single-cell RNA model was trained using 24K genes (the intersection with the ST data is 19K). Would you suggest training it with fewer genes?
Thanks in advance.
I finally got the model trained after 14h, but now that I want to load the model I am having the same problem.
mod = cell2location.models.Cell2location.load('results/cell2location_detailed_states_updated/cell2location_map', adata)
INFO File results/cell2location_detailed_states_updated/cell2location_map/model.pt already downloaded
INFO Preparing underlying module for load
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:120: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
Epoch 1/30000: 0%| | 0/30000 [00:00<?, ?it/s]
/home/cruiz2/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/distributions/gamma.py:71: UserWarning: Specified kernel cache directory could not be created! This disables kernel caching. Specified directory is /home/cruiz2/.cache/torch/kernels. This warning will appear only once per process. (Triggered internally at ../aten/src/ATen/native/cuda/jit_utils.cpp:860.)
self.rate * value - torch.lgamma(self.concentration))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:617, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
616 try:
--> 617 model.module.load_state_dict(model_state_dict)
618 except RuntimeError as err:
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
1496 if len(error_msgs) > 0:
-> 1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for Cell2locationBaseModule:
Unexpected key(s) in state_dict: "_guide.locs.m_g_mean_unconstrained", "_guide.locs.m_g_alpha_e_inv_unconstrained", "_guide.locs.m_g_unconstrained", "_guide.locs.n_s_cells_per_location_unconstrained", "_guide.locs.b_s_groups_per_location_unconstrained", "_guide.locs.z_sr_groups_factors_unconstrained", "_guide.locs.k_r_factors_per_groups_unconstrained", "_guide.locs.x_fr_group2fact_unconstrained", "_guide.locs.w_sf_unconstrained", "_guide.locs.detection_mean_y_e_unconstrained", "_guide.locs.detection_y_s_unconstrained", "_guide.locs.s_g_gene_add_alpha_hyp_unconstrained", "_guide.locs.s_g_gene_add_mean_unconstrained", "_guide.locs.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.locs.s_g_gene_add_unconstrained", "_guide.locs.alpha_g_phi_hyp_unconstrained", "_guide.locs.alpha_g_inverse_unconstrained", "_guide.scales.m_g_mean_unconstrained", "_guide.scales.m_g_alpha_e_inv_unconstrained", "_guide.scales.m_g_unconstrained", "_guide.scales.n_s_cells_per_location_unconstrained", "_guide.scales.b_s_groups_per_location_unconstrained", "_guide.scales.z_sr_groups_factors_unconstrained", "_guide.scales.k_r_factors_per_groups_unconstrained", "_guide.scales.x_fr_group2fact_unconstrained", "_guide.scales.w_sf_unconstrained", "_guide.scales.detection_mean_y_e_unconstrained", "_guide.scales.detection_y_s_unconstrained", "_guide.scales.s_g_gene_add_alpha_hyp_unconstrained", "_guide.scales.s_g_gene_add_mean_unconstrained", "_guide.scales.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.scales.s_g_gene_add_unconstrained", "_guide.scales.alpha_g_phi_hyp_unconstrained", "_guide.scales.alpha_g_inverse_unconstrained".
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 mod = cell2location.models.Cell2location.load('results/gbm_all_cell2location_detailed_states_updated/cell2location_map',
2 adata)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:622, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
620 old_history = model.history_.copy()
621 logger.info("Preparing underlying module for load")
--> 622 model.train(max_steps=1)
623 model.history_ = old_history
624 pyro.clear_param_store()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:209, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
206 scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_genes"])
207 kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 209 super().train(**kwargs)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:146, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
136 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
138 runner = TrainRunner(
139 self,
140 training_plan=training_plan,
(...)
144 **trainer_kwargs,
145 )
--> 146 return runner()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainrunner.py:74, in TrainRunner.__call__(self)
71 if hasattr(self.data_splitter, "n_val"):
72 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
75 self._update_history()
77 # data splitter only gets these attrs after fit
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
180 if isinstance(args[0], PyroTrainingPlan):
181 warnings.filterwarnings(
182 action="ignore",
183 category=UserWarning,
184 message="`LightningModule.configure_optimizers` returned `None`",
185 )
--> 186 super().fit(*args, **kwargs)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
735 rank_zero_deprecation(
736 "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
737 " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
738 )
739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
742 )
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
675 r"""
676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
677 as all errors should funnel through them
(...)
682 **kwargs: keyword arguments to be passed to `trainer_fn`
683 """
684 try:
--> 685 return trainer_fn(*args, **kwargs)
686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
687 except KeyboardInterrupt as exception:
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
775 # TODO: ckpt_path only in v1.7
776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
779 assert self.state.stopped
780 self.training = False
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1199, in Trainer._run(self, model, ckpt_path)
1196 self.checkpoint_connector.resume_end()
1198 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199 self._dispatch()
1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
1202 self._post_dispatch()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1279, in Trainer._dispatch(self)
1277 self.training_type_plugin.start_predicting(self)
1278 else:
-> 1279 self.training_type_plugin.start_training(self)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:202, in TrainingTypePlugin.start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
--> 202 self._results = trainer.run_stage()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1289, in Trainer.run_stage(self)
1287 if self.predicting:
1288 return self._run_predict()
-> 1289 return self._run_train()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1319, in Trainer._run_train(self)
1317 self.fit_loop.trainer = self
1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319 self.fit_loop.run()
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:234, in FitLoop.advance(self)
231 data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234 self.epoch_loop.run(data_fetcher)
236 # the global step is manually decreased here due to backwards compatibility with existing loggers
237 # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
238 # finished. this means the attribute does not exactly track the number of optimizer steps applied.
239 # TODO(@carmocca): deprecate and rename so users don't get confused
240 self.global_step -= 1
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:193, in TrainingEpochLoop.advance(self, *args, **kwargs)
190 self.batch_progress.increment_started()
192 with self.trainer.profiler.profile("run_training_batch"):
--> 193 batch_output = self.batch_loop.run(batch, batch_idx)
195 self.batch_progress.increment_processed()
197 # update non-plateau LR schedulers
198 # update epoch-interval ones only when we are at the end of training epoch
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:90, in TrainingBatchLoop.advance(self, batch, batch_idx)
88 outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
89 else:
---> 90 outputs = self.manual_loop.run(split_batch, batch_idx)
91 if outputs:
92 # automatic: can be empty if all optimizers skip their batches
93 # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
94 # then `advance` doesn't finish and an empty dict is returned
95 self._outputs.append(outputs)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:145, in Loop.run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/manual_loop.py:111, in ManualOptimization.advance(self, batch, batch_idx)
109 lightning_module._current_fx_name = "training_step"
110 with self.trainer.profiler.profile("training_step"):
--> 111 training_step_output = self.trainer.accelerator.training_step(step_kwargs)
112 self.trainer.training_type_plugin.post_training_step()
114 del step_kwargs
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py:219, in Accelerator.training_step(self, step_kwargs)
214 """The actual training step.
215
216 See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
217 """
218 with self.precision_plugin.train_step_context():
--> 219 return self.training_type_plugin.training_step(*step_kwargs.values())
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:213, in TrainingTypePlugin.training_step(self, *args, **kwargs)
212 def training_step(self, *args, **kwargs):
--> 213 return self.model.training_step(*args, **kwargs)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/train/_trainingplans.py:743, in PyroTrainingPlan.training_step(self, batch, batch_idx)
741 kwargs.update({"kl_weight": self.kl_weight})
742 # pytorch lightning requires a Tensor object for loss
--> 743 loss = torch.Tensor([self.svi.step(*args, **kwargs)])
745 return {"loss": loss}
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:157, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
153 if trainable_params and getattr(
154 surrogate_loss_particle, "requires_grad", False
155 ):
156 surrogate_loss_particle = surrogate_loss_particle / self.num_particles
--> 157 surrogate_loss_particle.backward(retain_graph=self.retain_graph)
158 warn_if_nan(loss, "loss")
159 return loss
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/_tensor.py:363, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
354 if has_torch_function_unary(self):
355 return handle_torch_function(
356 Tensor.backward,
357 (self,),
(...)
361 create_graph=create_graph,
362 inputs=inputs)
--> 363 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
168 retain_graph = create_graph
170 # The reason we repeat same the comment below is that
171 # some Python versions print out the first line of a multi-line function
172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors_, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True)
RuntimeError: CUDA out of memory. Tried to allocate 2.28 GiB (GPU 0; 39.44 GiB total capacity; 36.77 GiB already allocated; 91.62 MiB free; 37.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Any help, please!
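A possible workaround to try (not confirmed in this thread) is to run the load on CPU, using the use_gpu argument of BaseModelClass.load() that is visible in the traceback above:

import cell2location

# keep the one warm-up training step that load() performs off the GPU
mod = cell2location.models.Cell2location.load(
    'results/cell2location_detailed_states_updated/cell2location_map',
    adata,
    use_gpu=False,
)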
I also have the same problem:
OutOfMemoryError: CUDA out of memory. Tried to allocate 4.81 GiB (GPU 1; 79.21 GiB total capacity; 72.93 GiB already allocated; 4.35 GiB free; 73.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
This was also after running torch.cuda.empty_cache(). The CUDA memory output looks OK to me.
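For anyone checking the same thing, a minimal sketch of how to query the GPU memory state from Python with plain torch.cuda calls (device index 0 assumed):

import torch

total = torch.cuda.get_device_properties(0).total_memory
allocated = torch.cuda.memory_allocated(0)  # memory actively held by tensors
reserved = torch.cuda.memory_reserved(0)    # memory held by the caching allocator
print(f"total    : {total / 1024**3:.2f} GiB")
print(f"allocated: {allocated / 1024**3:.2f} GiB")
print(f"reserved : {reserved / 1024**3:.2f} GiB")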
If your problem is with loading the regression model or training the regression model - please check that you are using latest cell2location.
If your problem is with the main cell2location model, your dataset could be too large - it's realistic to use 16k genes * about 50k locations on an A100 80GB GPU.
@vitkl Thank you for the explanation. I ran into the CUDA memory problem when running the main cell2location model. My dataset was indeed larger (17k genes over 70k locations). After reducing the genes to 13k, CUDA is happy again.
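For anyone else hitting the same limit, here is a sketch of one way to cut down the gene set before running the main model, using the permissive filter_genes helper from the cell2location tutorials (the cutoff values are the tutorial defaults and purely illustrative; adjust them for your data):

from cell2location.utils.filtering import filter_genes

selected = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12,
)
adata_ref = adata_ref[:, selected].copy()  # keep only the selected genes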
I have actually gone through the mouse brain cell2location tutorial and I was planning to run the newer version of the Cell2location model on this dataset. The data preprocessing part was done successfully, but I was having some memory issues when training the RegressionModel.
Here is the block of code that produces the error:
Here I have tried different batch_size values (even 1), but it still produces the given error:
I have also stepped through the portion of the code which runs out of memory. The corresponding file was /usr/local/lib/python3.7/dist-packages/scvi/model/base/_pyromixin.py while running the code in Google Colab.
After stepping through this TrainRunner, the CUDA out of memory issue is displayed. I have actually used pdb to step through each line of this class. Could you please provide a reason for this memory issue, since I am not able to train the RegressionModel with the MouseBrain data, or suggest any ways to overcome it? Decreasing the batch_size parameter doesn't seem to work in this case.
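Since the code block itself was not captured above, here is a tutorial-style sketch of what such a RegressionModel training call typically looks like; the column names (batch_key, labels_key) and all parameter values are placeholders, not the poster's actual settings:

import cell2location

# register the reference AnnData with the model (column names are placeholders)
cell2location.models.RegressionModel.setup_anndata(
    adata=adata_ref,
    batch_key='sample',
    labels_key='cell_type',
)

mod = cell2location.models.RegressionModel(adata_ref)
mod.train(
    max_epochs=250,
    batch_size=512,  # small batch to limit GPU memory use
    use_gpu=True,
)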