BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

Error while computing log_prob at site 'data_target' #313

Open · Shiyc-Lab opened this issue 1 year ago

Shiyc-Lab commented 1 year ago

I encountered an error when running `mod.train(max_epochs=250, use_gpu=True)`:

```
ValueError                                Traceback (most recent call last)
File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    229 try:
--> 230     log_p = site["fn"].log_prob(
    231         site["value"], *site["args"], **site["kwargs"]
    232     )
    233 except ValueError as e:

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/distributions/conjugate.py:277, in GammaPoisson.log_prob(self, value)
    276 if self._validate_args:
--> 277     self._validate_sample(value)
    278 post_value = self.concentration + value

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
    299 if not valid.all():
--> 300     raise ValueError(
    301         "Expected value argument "
    302         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    303         f"to be within the support ({repr(support)}) "
    304         f"of the distribution {repr(self)}, "
    305         f"but found invalid values:\n{value}"
    306     )

ValueError: Expected value argument (Tensor of shape (2500, 1776)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution GammaPoisson(), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 2.1972],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.6931,  ..., 0.0000, 0.0000, 0.0000],
        [2.0794, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [3.2958, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.6931]],
       device='cuda:0')

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
File /home/shiyclab01/data/macaq_cortex_bgi/cell2loc.ipynb, Cell 22, line 1
----> 1 mod.train(max_epochs=200, use_gpu=True)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/cell2location/models/reference/_reference_model.py:157, in RegressionModel.train(self, max_epochs, batch_size, train_size, lr, **kwargs)
    154 kwargs["train_size"] = train_size
    155 kwargs["lr"] = lr
--> 157 super().train(**kwargs)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/scvi/model/base/_pyromixin.py:172, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    162     trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    164 runner = self._train_runner_cls(
    165     self,
    166     training_plan=training_plan,
   (...)
    170     **trainer_kwargs,
    171 )
--> 172 return runner()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/scvi/train/_trainrunner.py:82, in TrainRunner.__call__(self)
     79 if hasattr(self.data_splitter, "n_val"):
     80     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 82 self.trainer.fit(self.training_plan, self.data_splitter)
     83 self._update_history()
     85 # data splitter only gets these attrs after fit

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/scvi/train/_trainer.py:193, in Trainer.fit(self, *args, **kwargs)
    187 if isinstance(args[0], PyroTrainingPlan):
    188     warnings.filterwarnings(
    189         action="ignore",
    190         category=UserWarning,
    191         message="`LightningModule.configure_optimizers` returned `None`",
    192     )
--> 193 super().fit(*args, **kwargs)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:608, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    606 model = self._maybe_unwrap_optimized(model)
    607 self.strategy._lightning_module = model
--> 608 call._call_and_handle_interrupt(
    609     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    610 )

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37 else:
---> 38     return trainer_fn(*args, **kwargs)
     40 except _TunerExitException:
     41     trainer._call_teardown_hook()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    643 ckpt_path = ckpt_path or self.resume_from_checkpoint
    644 self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
    645     self.state.fn,
    646     ckpt_path,  # type: ignore[arg-type]
    647     model_provided=True,
    648     model_connected=self.lightning_module is not None,
    649 )
--> 650 self._run(model, ckpt_path=self.ckpt_path)
    652 assert self.state.stopped
    653 self.training = False

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1112, in Trainer._run(self, model, ckpt_path)
   1108     self._checkpoint_connector.restore_training_state()
   1110 self._checkpoint_connector.resume_end()
-> 1112 results = self._run_stage()
   1114 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1115 self._teardown()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1191, in Trainer._run_stage(self)
   1189 if self.predicting:
   1190     return self._run_predict()
-> 1191 self._run_train()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1214, in Trainer._run_train(self)
   1211 self.fit_loop.trainer = self
   1213 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1214     self.fit_loop.run()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:267, in FitLoop.advance(self)
    265 self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)
    266 with self.trainer.profiler.profile("run_training_epoch"):
--> 267     self._outputs = self.epoch_loop.run(self._data_fetcher)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:213, in TrainingEpochLoop.advance(self, data_fetcher)
    210 self.batch_progress.increment_started()
    212 with self.trainer.profiler.profile("run_training_batch"):
--> 213     batch_output = self.batch_loop.run(kwargs)
    215 self.batch_progress.increment_processed()
    217 # update non-plateau LR schedulers
    218 # update epoch-interval ones only when we are at the end of training epoch

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:90, in TrainingBatchLoop.advance(self, kwargs)
     88     outputs = self.optimizer_loop.run(optimizers, kwargs)
     89 else:
---> 90     outputs = self.manual_loop.run(kwargs)
     91 if outputs:
     92     # automatic: can be empty if all optimizers skip their batches
     93     # manual: #9052 added support for raising StopIteration in the training_step. If that happens,
     94     # then advance doesn't finish and an empty dict is returned
     95     self._outputs.append(outputs)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/manual_loop.py:110, in ManualOptimization.advance(self, kwargs)
    107 kwargs = self._build_kwargs(kwargs, self._hiddens)
    109 # manually capture logged metrics
--> 110 training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
    111 del kwargs  # release the batch from memory
    112 self.trainer.strategy.post_training_step()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1494, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1491     return
   1493 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1494     output = fn(*args, **kwargs)
   1496 # restore current_fx when nested context
   1497 pl_module._current_fx_name = prev_fx_name

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:378, in Strategy.training_step(self, *args, **kwargs)
    376 with self.precision_plugin.train_step_context():
    377     assert isinstance(self.model, TrainingStep)
--> 378     return self.model.training_step(*args, **kwargs)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/scvi/train/_trainingplans.py:938, in PyroTrainingPlan.training_step(self, batch, batch_idx)
    936     kwargs.update({"kl_weight": self.kl_weight})
    937 # pytorch lightning requires a Tensor object for loss
--> 938 loss = torch.Tensor([self.svi.step(*args, **kwargs)])
    940 _opt = self.optimizers()
    941 _opt.step()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):
--> 237         yield self._get_trace(model, guide, args, kwargs)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     72 guide_trace = prune_subsample_sites(guide_trace)
     73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
     76 guide_trace.compute_score_parts()
     77 if is_validation_enabled():

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/poutine/trace_struct.py:236, in Trace.compute_log_prob(self, site_filter)
    234 _, exc_value, traceback = sys.exc_info()
    235 shapes = self.format_shapes(last_site=site["name"])
--> 236 raise ValueError(
    237     "Error while computing log_prob at site '{}':\n{}\n{}".format(
    238         name, exc_value, shapes
    239     )
    240 ).with_traceback(traceback) from e
    241 site["unscaled_log_prob"] = log_p
    242 log_p = scale_and_mask(log_p, site["scale"], site["mask"])

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    228 if "log_prob" not in site:
    229     try:
--> 230         log_p = site["fn"].log_prob(
    231             site["value"], *site["args"], **site["kwargs"]
    232         )
    233     except ValueError as e:
    234         _, exc_value, traceback = sys.exc_info()

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/pyro/distributions/conjugate.py:277, in GammaPoisson.log_prob(self, value)
    275 def log_prob(self, value):
    276     if self._validate_args:
--> 277         self._validate_sample(value)
    278     post_value = self.concentration + value
    279     return (
    280         -log_beta(self.concentration, value + 1)
    281         - post_value.log()
    282         + self.concentration * self.rate.log()
    283         - post_value * (1 + self.rate).log()
    284     )

File ~/.conda/envs/pytorch2/lib/python3.8/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
    298 valid = support.check(value)
    299 if not valid.all():
--> 300     raise ValueError(
    301         "Expected value argument "
    302         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    303         f"to be within the support ({repr(support)}) "
    304         f"of the distribution {repr(self)}, "
    305         f"but found invalid values:\n{value}"
    306     )

ValueError: Error while computing log_prob at site 'data_target':
Expected value argument (Tensor of shape (2500, 1776)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution GammaPoisson(), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 2.1972],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.6931,  ..., 0.0000, 0.0000, 0.0000],
        [2.0794, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [3.2958, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.6931]],
       device='cuda:0')
                Trace Shapes:
                 Param Sites:
                Sample Sites:
       per_cluster_mu_fg dist           |   23 1776
                        value           |   23 1776
                     log_prob           |
      detection_mean_y_e dist           |    1    1
                        value           |    1    1
                     log_prob           |
  s_g_gene_add_alpha_hyp dist    1    1 |
                        value    1    1 |
                     log_prob    1    1 |
       s_g_gene_add_mean dist           |    1    1
                        value           |    1    1
                     log_prob           |
s_g_gene_add_alpha_e_inv dist           |    1    1
                        value           |    1    1
                     log_prob           |
            s_g_gene_add dist           |    1 1776
                        value           |    1 1776
                     log_prob           |
         alpha_g_phi_hyp dist    1    1 |
                        value    1    1 |
                     log_prob    1    1 |
         alpha_g_inverse dist           |    1 1776
                        value           |    1 1776
                     log_prob           |
             data_target dist 2500 1776 |
                        value 2500 1776 |
```

I extracted my raw data and there are no NaN or Inf values in the matrix; its size is about 570000 × 1600. I also do not know how `data_target` is calculated. What should I do? Any advice?
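
Note that a NaN/Inf check alone does not catch this error: the `GammaPoisson` support requires every entry to be a non-negative integer, and the tensor in the error message contains fractional values. A minimal sketch of a stricter check, assuming the AnnData object passed to the model is named `adata` (the name is an assumption):

```python
import numpy as np
import scipy.sparse as sp

# Sketch: verify that the matrix handed to the model contains raw counts.
# `adata` is an assumed name for the AnnData object used with the model.
X = adata.X
vals = X.data if sp.issparse(X) else np.asarray(X)

print("min value:   ", vals.min())                            # should be >= 0
print("all integers:", bool(np.all(vals == np.round(vals))))  # should be True
```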

vitkl commented 1 year ago

See https://github.com/BayraktarLab/cell2location/issues/128 - you need integer count data that has not been normalised in any way.
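
The fractional values in the error tensor are consistent with a `log1p` transform (0.6931 ≈ log1p(1), 2.0794 ≈ log1p(7), 2.1972 ≈ log1p(8)), which suggests the matrix was log-normalised before training. A minimal sketch of the fix, assuming the untransformed counts survived preprocessing in a layer named `counts` (the layer name and the `batch_key`/`labels_key` values below are assumptions about your data):

```python
from cell2location.models import RegressionModel

# Restore the untransformed integer counts before training.
# `adata.layers["counts"]` is an assumed location; `adata.raw` is another
# common place where raw counts survive scanpy normalisation.
adata.X = adata.layers["counts"].copy()

# Re-register the data and retrain on the integer matrix.
RegressionModel.setup_anndata(adata, batch_key="sample", labels_key="cell_type")
mod = RegressionModel(adata)
mod.train(max_epochs=250, use_gpu=True)
```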