BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

I don't have the raw count data; how can I deal with the mod.train() error? #204

Open FanZhang9 opened 2 years ago

FanZhang9 commented 2 years ago

According to issue #120, cell2location requires raw/untransformed/unnormalised counts for mapping, but I only have TPM data.
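A quick way to see whether a matrix could pass as raw counts is to test that every value is a non-negative integer. The sketch below is only an illustration: the helper name `looks_like_raw_counts` and the toy arrays are hypothetical, and it assumes a dense NumPy array (for a sparse matrix one would check its stored values instead).

```python
import numpy as np


def looks_like_raw_counts(x):
    """Return True if every entry is a non-negative whole number.

    Hypothetical helper for illustration; assumes a dense array-like input.
    """
    values = np.asarray(x, dtype=float)
    non_negative = np.all(values >= 0)
    whole_numbers = np.all(values == np.rint(values))
    return bool(non_negative and whole_numbers)


# Toy example: 2 cells x 3 genes of raw counts, and a TPM-like rescaling of them.
raw = np.array([[0, 4, 3], [2, 0, 1]])
tpm_like = raw / raw.sum(axis=1, keepdims=True) * 1e6

print(looks_like_raw_counts(raw))
print(looks_like_raw_counts(tpm_like))
```

A matrix that fails this check contains fractional or negative values, which count-based models generally cannot use as observed counts.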

vitkl commented 2 years ago

You need to get raw/untransformed/unnormalised counts. It is generally good practice to keep that data throughout the analysis. Some normalisation workflows can be undone, for example:

normalised_data = data / total_per_cell * 10000
lognormalised_data = log(normalised_data + 1)

can be undone as

normalised_data = exp(lognormalised_data) - 1
data = normalised_data / 10000 * total_per_cell
data = data.round().astype(int) # round to the nearest integer first; astype(int) alone truncates
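A minimal runnable sketch of this round trip with NumPy is given here, assuming `total_per_cell` is a vector holding one total count per cell (row). The function names and the toy data are made up for illustration; they are not part of cell2location.

```python
import numpy as np

SCALE = 10000  # scaling factor used in the workflow above


def lognormalise(counts, total_per_cell):
    """Forward step: scale each cell (row) to SCALE total counts, then log(x + 1)."""
    normalised = counts / total_per_cell[:, None] * SCALE
    return np.log1p(normalised)


def undo_lognormalise(lognormalised, total_per_cell):
    """Reverse step: expm1 undoes log1p, then rescale by each cell's original total."""
    normalised = np.expm1(lognormalised)
    counts = normalised / SCALE * total_per_cell[:, None]
    # Round before casting: astype(int) alone truncates, so 2.9999999 would become 2.
    return np.rint(counts).astype(int)


# Toy data: 5 cells x 8 genes of Poisson counts (hypothetical, for illustration only).
rng = np.random.default_rng(0)
counts = rng.poisson(2.0, size=(5, 8))
counts[:, 0] += 1  # guarantee every cell has a non-zero total
total_per_cell = counts.sum(axis=1)

lognormalised = lognormalise(counts, total_per_cell)
recovered = undo_lognormalise(lognormalised, total_per_cell)
print(np.array_equal(recovered, counts))
```

The recovered matrix should match the original counts exactly, because rounding absorbs the small floating-point error introduced by the log and exp steps.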

lllldf commented 7 months ago

Thank you for your method. What is "total_per_cell" in the scRNA-seq data structure? Or is it just a number?

vitkl commented 7 months ago

You need to know the total count per cell in the original raw count data matrix. Common workflows often save that in adata.obs, but it's possible that the authors removed that column, in which case there is not much you can do. You can also read this blog post for more details: https://www.nxn.se/valent/2018/10/25/unscaling-scaled-counts-in-scrna-seq-data
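When the totals are missing, one heuristic is to assume that the smallest non-zero value in each cell corresponds to a raw count of 1, so that the cell's total is the scaling factor divided by that value. This is only a sketch under that assumption (it fails for any cell with no gene at count 1), and the function name below is hypothetical. The input must be the normalised matrix, so a log-normalised matrix needs `np.expm1` applied first.

```python
import numpy as np


def estimate_total_per_cell(normalised, scale=10000):
    """Estimate each cell's original total count.

    Assumption (not guaranteed): the smallest non-zero value in a cell
    corresponds to a raw count of 1, i.e. min_nonzero = 1 / total * scale.
    """
    # Replace zeros with +inf so they are ignored by the row-wise minimum.
    masked = np.where(normalised > 0, normalised, np.inf)
    min_nonzero = masked.min(axis=1)
    # An all-zero cell gives scale / inf = 0.
    return scale / min_nonzero


# Toy example: 3 cells x 4 genes, each cell containing at least one count of 1.
counts = np.array([[0, 1, 3, 6], [2, 0, 1, 7], [1, 1, 0, 2]])
true_totals = counts.sum(axis=1)
normalised = counts / true_totals[:, None] * 10000

estimated_totals = estimate_total_per_cell(normalised)
recovered = np.rint(normalised / 10000 * estimated_totals[:, None]).astype(int)
print(estimated_totals)
print(np.array_equal(recovered, counts))
```

It is worth checking the recovered values before rounding: if they sit far from whole numbers, the assumption probably does not hold for this dataset.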

hezuoxi commented 6 months ago

Hello, I used the data from the three files produced by Cell Ranger, but it still shows the same error. What should I do?

vitkl commented 6 months ago

Which error do you see?

hezuoxi commented 6 months ago

This is the error

ValueError Traceback (most recent call last) File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:264, in Trace.compute_log_prob(self, site_filter) 263 try: --> 264 log_p = site["fn"].log_prob( 265 site["value"], *site["args"], **site["kwargs"] 266 ) 267 except ValueError as e:

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/distributions/conjugate.py:280, in GammaPoisson.log_prob(self, value) 279 if self._validate_args: --> 280 self._validate_sample(value) 281 post_value = self.concentration + value

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/torch/distributions/distribution.py:312, in Distribution._validate_sample(self, value) 311 if not valid.all(): --> 312 raise ValueError( 313 "Expected value argument " 314 f"({type(value).__name__} of shape {tuple(value.shape)}) " 315 f"to be within the support ({repr(support)}) " 316 f"of the distribution {repr(self)}, " 317 f"but found invalid values:\n{value}" 318 )

ValueError: Expected value argument (Tensor of shape (2500, 13228)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution GammaPoisson(), but found invalid values: tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.9727, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], device='cuda:0')

The above exception was the direct cause of the following exception:

ValueError Traceback (most recent call last) Cell In[12], line 1 ----> 1 mod.train(max_epochs=500, accelerator='gpu',train_size=1)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/cell2location/models/reference/_reference_model.py:157, in RegressionModel.train(self, max_epochs, batch_size, train_size, lr, **kwargs) 154 kwargs["train_size"] = train_size 155 kwargs["lr"] = lr --> 157 super().train(**kwargs)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:184, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs) 172 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) 174 runner = self._train_runner_cls( 175 self, 176 training_plan=training_plan, (...) 182 **trainer_kwargs, 183 ) --> 184 return runner()

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self) 96 if hasattr(self.data_splitter, "n_val"): 97 self.training_plan.n_obs_validation = self.data_splitter.n_val ---> 99 self.trainer.fit(self.training_plan, self.data_splitter) 100 self._update_history() 102 # data splitter only gets these attrs after fit

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs) 180 if isinstance(args[0], PyroTrainingPlan): 181 warnings.filterwarnings( 182 action="ignore", 183 category=UserWarning, 184 message="LightningModule.configure_optimizers returned None", 185 ) --> 186 super().fit(*args, **kwargs)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 530 self.strategy._lightning_module = model 531 _verify_strategy_supports_compile(model, self.strategy) --> 532 call._call_and_handle_interrupt( 533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 534 )

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 41 if trainer.strategy.launcher is not None: 42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) ---> 43 return trainer_fn(*args, **kwargs) 45 except _TunerExitException: 46 _call_teardown_hook(trainer)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 561 self._data_connector.attach_data( 562 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule 563 ) 565 ckpt_path = self._checkpoint_connector._select_ckpt_path( 566 self.state.fn, 567 ckpt_path, 568 model_provided=True, 569 model_connected=self.lightning_module is not None, 570 ) --> 571 self._run(model, ckpt_path=ckpt_path) 573 assert self.state.stopped 574 self.training = False

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:980, in Trainer._run(self, model, ckpt_path) 975 self._signal_connector.register_signal_handlers() 977 # ---------------------------- 978 # RUN THE TRAINER 979 # ---------------------------- --> 980 results = self._run_stage() 982 # ---------------------------- 983 # POST-Training CLEAN UP 984 # ---------------------------- 985 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1023, in Trainer._run_stage(self) 1021 self._run_sanity_check() 1022 with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> 1023 self.fit_loop.run() 1024 return None 1025 raise RuntimeError(f"Unexpected state {self.state}")

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self) 200 try: 201 self.on_advance_start() --> 202 self.advance() 203 self.on_advance_end() 204 self._restarting = False

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:355, in _FitLoop.advance(self) 353 self._data_fetcher.setup(combined_loader) 354 with self.trainer.profiler.profile("run_training_epoch"): --> 355 self.epoch_loop.run(self._data_fetcher)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher) 131 while not self.done: 132 try: --> 133 self.advance(data_fetcher) 134 self.on_advance_end() 135 self._restarting = False

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:221, in _TrainingEpochLoop.advance(self, data_fetcher) 219 batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs) 220 else: --> 221 batch_output = self.manual_optimization.run(kwargs) 223 self.batch_progress.increment_processed() 225 # update non-plateau LR schedulers 226 # update epoch-interval ones only when we are at the end of training epoch

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/manual.py:91, in _ManualOptimization.run(self, kwargs) 89 self.on_run_start() 90 with suppress(StopIteration): # no loop to break at this level ---> 91 self.advance(kwargs) 92 self._restarting = False 93 return self.on_run_end()

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/manual.py:111, in _ManualOptimization.advance(self, kwargs) 108 trainer = self.trainer 110 # manually capture logged metrics --> 111 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values()) 112 del kwargs # release the batch from memory 113 self.trainer.strategy.post_training_step()

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:294, in _call_strategy_hook(trainer, hook_name, *args, **kwargs) 291 return None 293 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"): --> 294 output = fn(*args, **kwargs) 296 # restore current_fx when nested context 297 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:380, in Strategy.training_step(self, *args, **kwargs) 378 with self.precision_plugin.train_step_context(): 379 assert isinstance(self.model, TrainingStep) --> 380 return self.model.training_step(*args, **kwargs)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/scvi/train/_trainingplans.py:1041, in PyroTrainingPlan.training_step(self, batch, batch_idx) 1039 kwargs.update({"kl_weight": self.kl_weight}) 1040 # pytorch lightning requires a Tensor object for loss -> 1041 loss = torch.Tensor([self.svi.step(*args, **kwargs)]) 1043 _opt = self.optimizers() 1044 _opt.step()

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs) 143 # get loss and compute gradients 144 with poutine.trace(param_only=True) as param_capture: --> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs) 147 params = set( 148 site["value"].unconstrained() for site in param_capture.trace.nodes.values() 149 ) 151 # actually perform gradient steps 152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs) 138 loss = 0.0 139 # grab a trace from the generator --> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): 141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle( 142 model_trace, guide_trace 143 ) 144 loss += loss_particle / self.num_particles

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs) 235 else: 236 for i in range(self.num_particles): --> 237 yield self._get_trace(model, guide, args, kwargs)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs) 52 def _get_trace(self, model, guide, args, kwargs): 53 """ 54 Returns a single trace from the guide, and the model that is run 55 against it. 56 """ ---> 57 model_trace, guide_trace = get_importance_trace( 58 "flat", self.max_plate_nesting, model, guide, args, kwargs 59 ) 60 if is_validation_enabled(): 61 check_if_enumerated(guide_trace)

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach) 72 guide_trace = prune_subsample_sites(guide_trace) 73 model_trace = prune_subsample_sites(model_trace) ---> 75 model_trace.compute_log_prob() 76 guide_trace.compute_score_parts() 77 if is_validation_enabled():

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:270, in Trace.compute_log_prob(self, site_filter) 268 _, exc_value, traceback = sys.exc_info() 269 shapes = self.format_shapes(last_site=site["name"]) --> 270 raise ValueError( 271 "Error while computing log_prob at site '{}':\n{}\n{}".format( 272 name, exc_value, shapes 273 ) 274 ).with_traceback(traceback) from e 275 site["unscaled_log_prob"] = log_p 276 log_p = scale_and_mask(log_p, site["scale"], site["mask"])

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:264, in Trace.compute_log_prob(self, site_filter) 262 if "log_prob" not in site: 263 try: --> 264 log_p = site["fn"].log_prob( 265 site["value"], *site["args"], **site["kwargs"] 266 ) 267 except ValueError as e: 268 _, exc_value, traceback = sys.exc_info()

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/pyro/distributions/conjugate.py:280, in GammaPoisson.log_prob(self, value) 278 def log_prob(self, value): 279 if self._validate_args: --> 280 self._validate_sample(value) 281 post_value = self.concentration + value 282 return ( 283 -log_beta(self.concentration, value + 1) 284 - post_value.log() 285 + self.concentration * self.rate.log() 286 - post_value * (1 + self.rate).log() 287 )

File ~/anaconda3/envs/cell2/lib/python3.9/site-packages/torch/distributions/distribution.py:312, in Distribution._validate_sample(self, value) 310 valid = support.check(value) 311 if not valid.all(): --> 312 raise ValueError( 313 "Expected value argument " 314 f"({type(value).__name__} of shape {tuple(value.shape)}) " 315 f"to be within the support ({repr(support)}) " 316 f"of the distribution {repr(self)}, " 317 f"but found invalid values:\n{value}" 318 )

ValueError: Error while computing log_prob at site 'data_target': Expected value argument (Tensor of shape (2500, 13228)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution GammaPoisson(), but found invalid values: tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.9727, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], device='cuda:0')

Trace Shapes:
Param Sites:
Sample Sites:
per_cluster_mu_fg dist | 15 13228
value | 15 13228
log_prob |
detection_mean_y_e dist | 85 1
value | 85 1
log_prob |
s_g_gene_add_alpha_hyp dist 1 1 |
value 1 1 |
log_prob 1 1 |
s_g_gene_add_mean dist | 85 1
value | 85 1
log_prob |
s_g_gene_add_alpha_e_inv dist | 85 1
value | 85 1
log_prob |
s_g_gene_add dist | 85 13228
value | 85 13228
log_prob |
alpha_g_phi_hyp dist 1 1 |
value 1 1 |
log_prob 1 1 |
alpha_g_inverse dist | 1 13228
value | 1 13228
log_prob |
data_target dist 2500 13228 |
value 2500 13228 |