BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0
321 stars 58 forks source link

error in export posterior #329

Closed sokratiag closed 1 year ago

sokratiag commented 1 year ago

Hi,

For the following trained model

# create and train the model

mod = cell2location.models.Cell2location( adata_vis, cell_state_df=inf_aver,

model_class=LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelNoMGPyroModel,

# the expected average cell abundance: tissue-dependent
# hyper-prior which can be estimated from paired histology:
m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1, "alpha_mean": 3},
N_cells_per_location=8,
A_factors_per_location=5.0,
B_groups_per_location=3.0,
N_cells_mean_var_ratio = 1.0,

# hyperparameter controlling normalisation of
# within-experiment variation in RNA detection:
detection_alpha=20

)
mod.view_anndata_setup()

I am getting an error when trying to export the posterior with mod.export_posterior (running on HPC):

Sampling local variables, batch: 0%| | 0/1 [00:00<?, ?it/s]

RuntimeError Traceback (most recent call last) File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs) 173 try: --> 174 ret = self.fn(*args, **kwargs) 175 except (ValueError, RuntimeError) as e:

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ):

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1517 else: -> 1518 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try:

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py:251, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index) 249 # =====================Gene expression level scaling m_g======================= # 250 # Explains difference in sensitivity for each gene between single cell and spatial technology --> 251 m_g_mean = pyro.sample( 252 "m_g_mean", 253 dist.Gamma( 254 self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, 255 self.m_g_mu_mean_var_ratio_hyp, 256 ) 257 .expand([1, 1]) 258 .to_event(2), 259 ) # (1, 1) 261 m_g_alpha_e_inv = pyro.sample( 262 "m_g_alpha_e_inv", 263 dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), 264 ) # (1, 1)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs) 162 # apply the stack and return its return value --> 163 apply_stack(msg) 164 return msg["value"]

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:217, in apply_stack(initial_msg) 215 break --> 217 default_process_message(msg) 219 for frame in stack[-pointer:]:

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:179, in default_process_message(msg) 177 return msg --> 179 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) 181 # after fn has been called, update msg to prevent it from being called again.

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/distributions/torch_distribution.py:48, in TorchDistributionMixin.__call__(self, sample_shape) 31 """ 32 Samples a random value. 33 (...) 45 :rtype: torch.Tensor 46 """ 47 return ( ---> 48 self.rsample(sample_shape) 49 if self.has_rsample 50 else self.sample(sample_shape) 51 )

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/independent.py:104, in Independent.rsample(self, sample_shape) 103 def rsample(self, sample_shape=torch.Size()): --> 104 return self.base_dist.rsample(sample_shape)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:71, in Gamma.rsample(self, sample_shape) 70 shape = self._extended_shape(sample_shape) ---> 71 value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( 72 shape 73 ) 74 value.detach().clamp( 75 min=torch.finfo(value.dtype).tiny 76 ) # do not record in autograd graph

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:12, in _standard_gamma(concentration) 11 def _standard_gamma(concentration): ---> 12 return torch._standard_gamma(concentration)

RuntimeError: "gamma_cuda" not implemented for 'Long'

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last) Cell In[64], line 2 1 # In this section, we export the estimated cell abundance (summary of the posterior distribution). ----> 2 adata_vis = mod.export_posterior( 3 adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': True} 4 ) 6 # Save model 7 # mod.save(f"{run_name}", overwrite=True) 8 (...) 13 # adata_vis.write(adata_file) 14 # adata_file

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:352, in Cell2location.export_posterior(self, adata, sample_kwargs, export_slot, add_to_obsm, use_quantiles) 348 self.samples[f"postsample{i}"] = self.posterior_quantile(q=q, **sample_kwargs) 349 else: 350 # generate samples from posterior distributions for all parameters 351 # and compute mean, 5%/95% quantiles and standard deviation --> 352 self.samples = self.sample_posterior(**sample_kwargs) 354 # export posterior distribution summary for all parameters and 355 # annotation (model, date, var, obs and cell type names) to anndata object 356 adata.uns[export_slot] = self._export2adata(self.samples)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:529, in PyroSampleMixin.sample_posterior(self, num_samples, return_sites, use_gpu, accelerator, device, batch_size, return_observed, return_samples, summary_fun) 483 """Summarise posterior distribution. 484 485 Generate samples from posterior distribution for each parameter (...) 526 to keep all model-specific variables in one place. 527 """ 528 # sample using minibatches (if full data, data is moved to GPU only once anyway) --> 529 samples = self._posterior_samples_minibatch( 530 use_gpu=use_gpu, 531 accelerator=accelerator, 532 device=device, 533 batch_size=batch_size, 534 num_samples=num_samples, 535 return_sites=return_sites, 536 return_observed=return_observed, 537 ) 539 param_names = list(samples.keys()) 540 results = {}

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:415, in PyroSampleMixin._posterior_samples_minibatch(self, use_gpu, accelerator, device, batch_size, **sample_kwargs) 413 if i == 0: 414 return_observed = getattr(sample_kwargs, "return_observed", False) --> 415 obs_plate_sites = self._get_obs_plate_sites( 416 args, kwargs, return_observed=return_observed 417 ) 418 if len(obs_plate_sites) == 0: 419 # if no local variables - don't sample 420 break

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:339, in PyroSampleMixin._get_obs_plate_sites(self, args, kwargs, return_observed) 336 plate_name = self.module.list_obs_plate_vars["name"] 338 # find plate dimension --> 339 trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) 340 obs_plate = { 341 name: site["cond_indep_stack"][0].dim 342 for name, site in trace.nodes.items() (...) 355 if any(f.name == plate_name for f in site["cond_indep_stack"]) 356 } 358 return obs_plate

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs) 190 def get_trace(self, *args, **kwargs): 191 """ 192 :returns: data structure 193 :rtype: pyro.poutine.Trace (...) 196 Calls this poutine and returns its trace instead of the function's return value. 197 """ --> 198 self(*args, **kwargs) 199 return self.msngr.get_trace()

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs) 178 exc = exc_type("{}\n{}".format(exc_value, shapes)) 179 exc = exc.with_traceback(traceback) --> 180 raise exc from e 181 self.msngr.trace.add_node( 182 "_RETURN", name="_RETURN", type="return", value=ret 183 ) 184 return ret

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs) 170 self.msngr.trace.add_node( 171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs 172 ) 173 try: --> 174 ret = self.fn(*args, **kwargs) 175 except (ValueError, RuntimeError) as e: 176 exc_type, exc_value, traceback = sys.exc_info()

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 447 def __call__(self, *args, **kwargs): 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ): 455 self._check_module_local_param_usage()

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py:251, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index) 247 obs_plate = self.create_plates(x_data, idx, batch_index) 249 # =====================Gene expression level scaling m_g======================= # 250 # Explains difference in sensitivity for each gene between single cell and spatial technology --> 251 m_g_mean = pyro.sample( 252 "m_g_mean", 253 dist.Gamma( 254 self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, 255 self.m_g_mu_mean_var_ratio_hyp, 256 ) 257 .expand([1, 1]) 258 .to_event(2), 259 ) # (1, 1) 261 m_g_alpha_e_inv = pyro.sample( 262 "m_g_alpha_e_inv", 263 dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), 264 ) # (1, 1) 265 m_g_alpha_e = self.ones / m_g_alpha_e_inv.pow(2)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs) 146 msg = { 147 "type": "sample", 148 "name": name, (...) 160 "continuation": None, 161 } 162 # apply the stack and return its return value --> 163 apply_stack(msg) 164 return msg["value"]

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:217, in apply_stack(initial_msg) 214 if msg["stop"]: 215 break --> 217 default_process_message(msg) 219 for frame in stack[-pointer:]: 220 frame._postprocess_message(msg)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:179, in default_process_message(msg) 176 msg["done"] = True 177 return msg --> 179 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) 181 # after fn has been called, update msg to prevent it from being called again. 182 msg["done"] = True

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/distributions/torch_distribution.py:48, in TorchDistributionMixin.__call__(self, sample_shape) 30 def __call__(self, sample_shape=torch.Size()): 31 """ 32 Samples a random value. 33 (...) 45 :rtype: torch.Tensor 46 """ 47 return ( ---> 48 self.rsample(sample_shape) 49 if self.has_rsample 50 else self.sample(sample_shape) 51 )

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/independent.py:104, in Independent.rsample(self, sample_shape) 103 def rsample(self, sample_shape=torch.Size()): --> 104 return self.base_dist.rsample(sample_shape)

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:71, in Gamma.rsample(self, sample_shape) 69 def rsample(self, sample_shape=torch.Size()): 70 shape = self._extended_shape(sample_shape) ---> 71 value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( 72 shape 73 ) 74 value.detach().clamp( 75 min=torch.finfo(value.dtype).tiny 76 ) # do not record in autograd graph 77 return value

File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:12, in _standard_gamma(concentration) 11 def _standard_gamma(concentration): ---> 12 return torch._standard_gamma(concentration)

RuntimeError: "gamma_cuda" not implemented for 'Long'
Trace Shapes:
Param Sites:
Sample Sites:
obs_plate dist | value 1609 |

I'd appreciate your help.

sokratiag commented 1 year ago

Hi,

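OK, I managed to find the cause of the error: the priors should be defined as floats, not as integers.

m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0} works,

whereas m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1, "alpha_mean": 3} apparently does not, and throws the above error. The reason is that the Gamma prior on m_g_mean uses mean_var_ratio * mean as its concentration. When both values are Python integers, that concentration is an integer ('Long') tensor, and PyTorch's GPU gamma sampler ("gamma_cuda") is not implemented for integer tensors. Writing all prior values as floats (e.g. 1.0 instead of 1) avoids the problem.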

Thanks.