BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

ValueError: Expected parameter rate (Tensor of shape (1, 1)) of distribution Gamma(concentration: tensor([[10.]], device='cuda:0') #343

Closed RolantusdataExp closed 10 months ago

RolantusdataExp commented 10 months ago


Problem

Dear Cell2Location team, I am new to Python and ran into the following error when running:

mod_st.train(max_epochs=30000,
             # train using full data (batch_size=None)
             batch_size=None,
             # use all data points in training because
             # we need to estimate cell abundance at all locations
             train_size=1,
             accelerator="gpu"
             )

The error: ValueError: Expected parameter rate (Tensor of shape (1, 1)) of distribution Gamma(concentration: tensor([[10.]], device='cuda:0'), rate: tensor([[nan]], device='cuda:0')) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values: tensor([[nan]], device='cuda:0')

Is this a problem with my GPU or with one of the settings?

The data are from a human organ.

Single cell reference data: number of cells, number of cell types, number of genes

Number of cells: 136,000; number of cell types: 29; number of genes: >10,000

Spatial data: number of locations, technology type (e.g. Visium, ISS, Nanostring WTA)

Number of locations: 600; technology: Visium

Best regards, Peter

RolantusdataExp commented 10 months ago

The whole error output:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ucloud/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:69: You passed in a val_dataloader but have no validation_step. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ucloud/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:281: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

ValueError Traceback (most recent call last)
File ~/.local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
    450 if (
    451     pyro.settings.get("validate_poutine")
    452     and not self._pyro_context.active
    453     and _is_module_local_param_enabled()
    454 ):

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509 else:
-> 1510     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519, in Module._call_impl(self, *args, **kwargs)
   1516 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1517         or _global_backward_pre_hooks or _global_backward_hooks
   1518         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1519     return forward_call(*args, **kwargs)
   1521 try:

File ~/.local/lib/python3.10/site-packages/cell2location/models/_cell2location_module.py:386, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index)
    382 # =====================Location-specific detection efficiency ======================= #
    383 # y_s with hierarchical mean prior
    384 detection_mean_y_e = pyro.sample(
    385     "detection_mean_y_e",
--> 386     dist.Gamma(
    387         self.ones * self.detection_mean_hyp_prior_alpha,
    388         self.ones * self.detection_mean_hyp_prior_beta,
    389     )
    390     .expand([self.n_batch, 1])
    391     .to_event(2),
    392 )
    393 detection_hyp_prior_alpha = pyro.deterministic(
    394     "detection_hyp_prior_alpha",
    395     self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
    396 )

File ~/.local/lib/python3.10/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
     23     return result
---> 24 return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/gamma.py:58, in Gamma.__init__(self, concentration, rate, validate_args)
     57     batch_shape = self.concentration.size()
---> 58 super().__init__(batch_shape, validate_args=validate_args)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/distribution.py:68, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     67 if not valid.all():
---> 68     raise ValueError(
     69         f"Expected parameter {param} "
     70         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     71         f"of distribution {repr(self)} "
     72         f"to satisfy the constraint {repr(constraint)}, "
     73         f"but found invalid values:\n{value}"
     74     )
     75 super().__init__()

ValueError: Expected parameter rate (Tensor of shape (1, 1)) of distribution Gamma(concentration: tensor([[10.]], device='cuda:0'), rate: tensor([[nan]], device='cuda:0')) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values: tensor([[nan]], device='cuda:0')

The above exception was the direct cause of the following exception:

ValueError Traceback (most recent call last)
Cell In[44], line 1
----> 1 mod_st.train(max_epochs=30000,
      2              # train using full data (batch_size=None)
      3              batch_size=None,
      4              # use all data points in training because
      5              # we need to estimate cell abundance at all locations
      6              train_size=1,
      7              accelerator="gpu"
      8              )

File ~/.local/lib/python3.10/site-packages/cell2location/models/_cell2location_model.py:209, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
    206     scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"])
    207 kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 209 super().train(**kwargs)

File ~/.local/lib/python3.10/site-packages/scvi/model/base/_pyromixin.py:184, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    172     trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    174 runner = self._train_runner_cls(
    175     self,
    176     training_plan=training_plan,
   (...)
    182     **trainer_kwargs,
    183 )
--> 184 return runner()

File ~/.local/lib/python3.10/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File ~/.local/lib/python3.10/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="LightningModule.configure_optimizers returned None",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41 if trainer.strategy.launcher is not None:
     42     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43 return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
    567     ckpt_path,
    568     model_provided=True,
    569     model_connected=self.lightning_module is not None,
    570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
    573 assert self.state.stopped
    574 self.training = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:980, in Trainer._run(self, model, ckpt_path)
    975 self._signal_connector.register_signal_handlers()
    977 # ----------------------------
    978 # RUN THE TRAINER
    979 # ----------------------------
--> 980 results = self._run_stage()
    982 # ----------------------------
    983 # POST-Training CLEAN UP
    984 # ----------------------------
    985 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1021     self._run_sanity_check()
   1022 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1023     self.fit_loop.run()
   1024     return None
   1025 raise RuntimeError(f"Unexpected state {self.state}")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:198, in _FitLoop.run(self)
    196     return
    197 self.reset()
--> 198 self.on_run_start()
    199 while not self.done:
    200     try:

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:316, in _FitLoop.on_run_start(self)
    312 self._data_fetcher = _select_data_fetcher(trainer)
    314 self._results.to(device=trainer.lightning_module.device)
--> 316 call._call_callback_hooks(trainer, "on_train_start")
    317 call._call_lightning_module_hook(trainer, "on_train_start")
    318 call._call_strategy_hook(trainer, "on_train_start")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:195, in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
    193 if callable(fn):
    194     with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 195         fn(trainer, trainer.lightning_module, *args, **kwargs)
    197 if pl_module:
    198     # restore current_fx when nested context
    199     pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.10/site-packages/scvi/model/base/_pyromixin.py:44, in PyroJitGuideWarmup.on_train_start(self, trainer, pl_module)
     42 tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
     43 args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
---> 44 pyro_guide(*args, **kwargs)
     45 break

File ~/.local/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510, in Module._wrapped_call_impl(self, *args, **kwargs)
   1508     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1509 else:
-> 1510     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519, in Module._call_impl(self, *args, **kwargs)
   1514 # If we don't have any hooks, we want to skip the rest of the logic in
   1515 # this function, and just call forward.
   1516 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1517         or _global_backward_pre_hooks or _global_backward_hooks
   1518         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1519     return forward_call(*args, **kwargs)
   1521 try:
   1522     result = None

File ~/.local/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:510, in AutoNormal.forward(self, *args, **kwargs)
    508 # if we've never run the model before, do so now so we can inspect the model structure
    509 if self.prototype_trace is None:
--> 510     self._setup_prototype(*args, **kwargs)
    512 plates = self._create_plates(*args, **kwargs)
    513 result = {}

File ~/.local/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:460, in AutoNormal._setup_prototype(self, *args, **kwargs)
    459 def _setup_prototype(self, *args, **kwargs):
--> 460     super()._setup_prototype(*args, **kwargs)
    462     self._event_dims = {}
    463     self.locs = PyroModule()

File ~/.local/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157, in AutoGuide._setup_prototype(self, *args, **kwargs)
    154 def _setup_prototype(self, *args, **kwargs):
    155     # run the model so we can inspect its structure
    156     model = poutine.block(self.model, self._prototype_hide_fn)
--> 157     self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
    158         *args, **kwargs
    159     )
    160     if self.master is not None:
    161         self.master()._check_prototype(self.prototype_trace)

File ~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/.local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
    178     exc = exc_type("{}\n{}".format(exc_value, shapes))
    179     exc = exc.with_traceback(traceback)
--> 180     raise exc from e
    181 self.msngr.trace.add_node(
    182     "_RETURN", name="_RETURN", type="return", value=ret
    183 )
    184 return ret

File ~/.local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510, in Module._wrapped_call_impl(self, *args, **kwargs)
   1508     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1509 else:
-> 1510     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519, in Module._call_impl(self, *args, **kwargs)
   1514 # If we don't have any hooks, we want to skip the rest of the logic in
   1515 # this function, and just call forward.
   1516 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1517         or _global_backward_pre_hooks or _global_backward_hooks
   1518         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1519     return forward_call(*args, **kwargs)
   1521 try:
   1522     result = None

File ~/.local/lib/python3.10/site-packages/cell2location/models/_cell2location_module.py:386, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index)
    373 pyro.sample(
    374     k + "_initial",
    375     dist.Gamma(
   (...)
    379     obs=w_sf,
    380 )  # (self.n_obs, self.n_factors)
    382 # =====================Location-specific detection efficiency ======================= #
    383 # y_s with hierarchical mean prior
    384 detection_mean_y_e = pyro.sample(
    385     "detection_mean_y_e",
--> 386     dist.Gamma(
    387         self.ones * self.detection_mean_hyp_prior_alpha,
    388         self.ones * self.detection_mean_hyp_prior_beta,
    389     )
    390     .expand([self.n_batch, 1])
    391     .to_event(2),
    392 )
    393 detection_hyp_prior_alpha = pyro.deterministic(
    394     "detection_hyp_prior_alpha",
    395     self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
    396 )
    398 beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e)

File ~/.local/lib/python3.10/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
     22 if result is not None:
     23     return result
---> 24 return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/gamma.py:58, in Gamma.__init__(self, concentration, rate, validate_args)
     56 else:
     57     batch_shape = self.concentration.size()
---> 58 super().__init__(batch_shape, validate_args=validate_args)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/distribution.py:68, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     66 valid = constraint.check(value)
     67 if not valid.all():
---> 68     raise ValueError(
     69         f"Expected parameter {param} "
     70         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     71         f"of distribution {repr(self)} "
     72         f"to satisfy the constraint {repr(constraint)}, "
     73         f"but found invalid values:\n{value}"
     74     )
     75 super().__init__()

ValueError: Expected parameter rate (Tensor of shape (1, 1)) of distribution Gamma(concentration: tensor([[10.]], device='cuda:0'), rate: tensor([[nan]], device='cuda:0')) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values: tensor([[nan]], device='cuda:0')
Trace Shapes:
 Param Sites:
Sample Sites:
                m_g_mean dist        |  1  1
                         value       |  1  1
         m_g_alpha_e_inv dist        |  1  1
                         value       |  1  1
                     m_g dist        |  1  0
                         value       |  1  0
  n_s_cells_per_location dist 626  1 |
                         value 626  1 |
 b_s_groups_per_location dist 626  1 |
                         value 626  1 |
     z_sr_groups_factors dist 626 50 |
                         value 626 50 |
 k_r_factors_per_groups  dist        | 50  1
                         value       | 50  1
         x_fr_group2fact dist        | 50 24
                         value       | 50 24
                    w_sf dist 626 24 |
                         value 626 24 |

vitkl commented 10 months ago

Please check that the data you are attempting to use (registered with setup_anndata) is indeed integer counts and that you don't have cells with exactly 0 total counts.
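A minimal sketch of such a check (assuming the counts were registered from adata.X; adapt the location if a layer was passed to setup_anndata instead):

import numpy as np
import scipy.sparse as sp

X = adata.X  # or e.g. adata.layers["counts"], whichever was registered
vals = X.data if sp.issparse(X) else np.asarray(X)

# Counts must be non-negative integers (not normalised or log-transformed).
assert np.all(vals >= 0), "negative values found - not raw counts"
assert np.allclose(vals, np.round(vals)), "non-integer values found - not raw counts"

# No cells/locations with exactly 0 total counts.
totals = np.asarray(X.sum(axis=1)).ravel()
print((totals == 0).sum(), "observations with zero total counts")
adata = adata[totals > 0].copy()  # drop empty observations before training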

RolantusdataExp commented 10 months ago

Problem solved, thanks for the help

rue1996 commented 10 months ago

@RolantusdataExp how did you solve this issue?

RolantusdataExp commented 10 months ago

Hi @rue1996 - I realised that my signature matrix and spatial matrix used different gene annotations, causing the intersection step to return an empty matrix.
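For context, the standard cell2location workflow subsets both objects to their shared genes roughly like the sketch below (adata_vis for the spatial AnnData and inf_aver for the reference signature DataFrame follow the tutorial's naming; these are assumptions, adjust to your own variables). If the two objects use different annotations, e.g. gene symbols vs Ensembl IDs, the intersection comes out empty and downstream priors become NaN:

import numpy as np

# Genes present in both the spatial data and the reference signatures.
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
print(len(intersect), "shared genes")  # should be in the thousands, not 0

adata_vis = adata_vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()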

capuddddd commented 1 month ago

Hi @rue1996 - I realised that my signature matrix and spatial matrix used different gene annotations, causing the intersection step to return an empty matrix.

@RolantusdataExp I got the same error, but I am confused by "different gene annotations" - did you mean the genome version or just the n_vars of the AnnData?

vitkl commented 1 month ago

I would make sure that after filtering by shared var_names you don't have an empty AnnData or DataFrame.
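A quick guard along those lines, reusing the hypothetical names from the sketch above:

# Fail early with a clear message instead of a NaN inside a Gamma prior.
assert len(intersect) > 0, "no shared genes - check annotations (symbols vs Ensembl IDs)"
assert adata_vis.n_obs > 0 and adata_vis.n_vars > 0, "spatial AnnData is empty"
assert not inf_aver.empty, "signature DataFrame is empty"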

capuddddd commented 1 month ago

I would make sure that after filtering by shared var_names you don't have an empty AnnData or DataFrame.

Thank you for your explanation! I have fixed the problem.