Closed sokratiag closed 1 year ago
Hi,
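OK, I found the cause of the error: the prior values should be given as floats, not integers. With integer values the corresponding tensor ends up with an integer ('Long') dtype, and PyTorch's GPU Gamma sampler (`gamma_cuda`) is not implemented for that dtype, as the last line of the traceback shows.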
m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0} is OK
whereas m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1, "alpha_mean": 3} is not, and throws the error shown in the original report.
Thanks.
Hi,
For the following trained model:
# create and train the model
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    model_class=LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelNoMGPyroModel,
)
mod.view_anndata_setup()
I get an error when trying to export the posterior with mod.export_posterior (running on HPC):
Sampling local variables, batch: 0%| | 0/1 [00:00<?, ?it/s]
RuntimeError Traceback (most recent call last) File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs) 173 try: --> 174 ret = self.fn(*args, **kwargs) 175 except (ValueError, RuntimeError) as e:
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ):
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1517 else: -> 1518 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try:
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py:251, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index) 249 # =====================Gene expression level scaling m_g======================= # 250 # Explains difference in sensitivity for each gene between single cell and spatial technology --> 251 m_g_mean = pyro.sample( 252 "m_g_mean", 253 dist.Gamma( 254 self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, 255 self.m_g_mu_mean_var_ratio_hyp, 256 ) 257 .expand([1, 1]) 258 .to_event(2), 259 ) # (1, 1) 261 m_g_alpha_e_inv = pyro.sample( 262 "m_g_alpha_e_inv", 263 dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), 264 ) # (1, 1)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs) 162 # apply the stack and return its return value --> 163 apply_stack(msg) 164 return msg["value"]
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:217, in apply_stack(initial_msg) 215 break --> 217 default_process_message(msg) 219 for frame in stack[-pointer:]:
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:179, in default_process_message(msg) 177 return msg --> 179 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) 181 # after fn has been called, update msg to prevent it from being called again.
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/distributions/torch_distribution.py:48, in TorchDistributionMixin.__call__(self, sample_shape) 31 """ 32 Samples a random value. 33 (...) 45 :rtype: torch.Tensor 46 """ 47 return ( ---> 48 self.rsample(sample_shape) 49 if self.has_rsample 50 else self.sample(sample_shape) 51 )
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/independent.py:104, in Independent.rsample(self, sample_shape) 103 def rsample(self, sample_shape=torch.Size()): --> 104 return self.base_dist.rsample(sample_shape)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:71, in Gamma.rsample(self, sample_shape) 70 shape = self._extended_shape(sample_shape) ---> 71 value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( 72 shape 73 ) 74 value.detach().clamp( 75 min=torch.finfo(value.dtype).tiny 76 ) # do not record in autograd graph
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:12, in _standard_gamma(concentration) 11 def _standard_gamma(concentration): ---> 12 return torch._standard_gamma(concentration)
RuntimeError: "gamma_cuda" not implemented for 'Long'
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last) Cell In[64], line 2 1 # In this section, we export the estimated cell abundance (summary of the posterior distribution). ----> 2 adata_vis = mod.export_posterior( 3 adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': True} 4 ) 6 # Save model 7 # mod.save(f"{run_name}", overwrite=True) 8 (...) 13 # adata_vis.write(adata_file) 14 # adata_file
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:352, in Cell2location.export_posterior(self, adata, sample_kwargs, export_slot, add_to_obsm, use_quantiles) 348 self.samples[f"postsample{i}"] = self.posterior_quantile(q=q, **sample_kwargs) 349 else: 350 # generate samples from posterior distributions for all parameters 351 # and compute mean, 5%/95% quantiles and standard deviation --> 352 self.samples = self.sample_posterior(**sample_kwargs) 354 # export posterior distribution summary for all parameters and 355 # annotation (model, date, var, obs and cell type names) to anndata object 356 adata.uns[export_slot] = self._export2adata(self.samples)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:529, in PyroSampleMixin.sample_posterior(self, num_samples, return_sites, use_gpu, accelerator, device, batch_size, return_observed, return_samples, summary_fun) 483 """Summarise posterior distribution. 484 485 Generate samples from posterior distribution for each parameter (...) 526 to keep all model-specific variables in one place. 527 """ 528 # sample using minibatches (if full data, data is moved to GPU only once anyway) --> 529 samples = self._posterior_samples_minibatch( 530 use_gpu=use_gpu, 531 accelerator=accelerator, 532 device=device, 533 batch_size=batch_size, 534 num_samples=num_samples, 535 return_sites=return_sites, 536 return_observed=return_observed, 537 ) 539 param_names = list(samples.keys()) 540 results = {}
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:415, in PyroSampleMixin._posterior_samples_minibatch(self, use_gpu, accelerator, device, batch_size, **sample_kwargs) 413 if i == 0: 414 return_observed = getattr(sample_kwargs, "return_observed", False) --> 415 obs_plate_sites = self._get_obs_plate_sites( 416 args, kwargs, return_observed=return_observed 417 ) 418 if len(obs_plate_sites) == 0: 419 # if no local variables - don't sample 420 break
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:339, in PyroSampleMixin._get_obs_plate_sites(self, args, kwargs, return_observed) 336 plate_name = self.module.list_obs_plate_vars["name"] 338 # find plate dimension --> 339 trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) 340 obs_plate = { 341 name: site["cond_indep_stack"][0].dim 342 for name, site in trace.nodes.items() (...) 355 if any(f.name == plate_name for f in site["cond_indep_stack"]) 356 } 358 return obs_plate
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs) 190 def get_trace(self, *args, **kwargs): 191 """ 192 :returns: data structure 193 :rtype: pyro.poutine.Trace (...) 196 Calls this poutine and returns its trace instead of the function's return value. 197 """ --> 198 self(*args, **kwargs) 199 return self.msngr.get_trace()
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs) 178 exc = exc_type("{}\n{}".format(exc_value, shapes)) 179 exc = exc.with_traceback(traceback) --> 180 raise exc from e 181 self.msngr.trace.add_node( 182 "_RETURN", name="_RETURN", type="return", value=ret 183 ) 184 return ret
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs) 170 self.msngr.trace.add_node( 171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs 172 ) 173 try: --> 174 ret = self.fn(*args, **kwargs) 175 except (ValueError, RuntimeError) as e: 176 exc_type, exc_value, traceback = sys.exc_info()
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 447 def __call__(self, *args, **kwargs): 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ): 455 self._check_module_local_param_usage()
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py:251, in LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel.forward(self, x_data, idx, batch_index) 247 obs_plate = self.create_plates(x_data, idx, batch_index) 249 # =====================Gene expression level scaling m_g======================= # 250 # Explains difference in sensitivity for each gene between single cell and spatial technology --> 251 m_g_mean = pyro.sample( 252 "m_g_mean", 253 dist.Gamma( 254 self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, 255 self.m_g_mu_mean_var_ratio_hyp, 256 ) 257 .expand([1, 1]) 258 .to_event(2), 259 ) # (1, 1) 261 m_g_alpha_e_inv = pyro.sample( 262 "m_g_alpha_e_inv", 263 dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), 264 ) # (1, 1) 265 m_g_alpha_e = self.ones / m_g_alpha_e_inv.pow(2)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs) 146 msg = { 147 "type": "sample", 148 "name": name, (...) 160 "continuation": None, 161 } 162 # apply the stack and return its return value --> 163 apply_stack(msg) 164 return msg["value"]
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:217, in apply_stack(initial_msg) 214 if msg["stop"]: 215 break --> 217 default_process_message(msg) 219 for frame in stack[-pointer:]: 220 frame._postprocess_message(msg)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/poutine/runtime.py:179, in default_process_message(msg) 176 msg["done"] = True 177 return msg --> 179 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) 181 # after fn has been called, update msg to prevent it from being called again. 182 msg["done"] = True
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/pyro/distributions/torch_distribution.py:48, in TorchDistributionMixin.__call__(self, sample_shape) 30 def __call__(self, sample_shape=torch.Size()): 31 """ 32 Samples a random value. 33 (...) 45 :rtype: torch.Tensor 46 """ 47 return ( ---> 48 self.rsample(sample_shape) 49 if self.has_rsample 50 else self.sample(sample_shape) 51 )
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/independent.py:104, in Independent.rsample(self, sample_shape) 103 def rsample(self, sample_shape=torch.Size()): --> 104 return self.base_dist.rsample(sample_shape)
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:71, in Gamma.rsample(self, sample_shape) 69 def rsample(self, sample_shape=torch.Size()): 70 shape = self._extended_shape(sample_shape) ---> 71 value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( 72 shape 73 ) 74 value.detach().clamp( 75 min=torch.finfo(value.dtype).tiny 76 ) # do not record in autograd graph 77 return value
File ~/.conda/envs/cell2loc_env_new3/lib/python3.9/site-packages/torch/distributions/gamma.py:12, in _standard_gamma(concentration) 11 def _standard_gamma(concentration): ---> 12 return torch._standard_gamma(concentration)
RuntimeError: "gamma_cuda" not implemented for 'Long'
Trace Shapes:
Param Sites:
Sample Sites:
obs_plate dist | value 1609 |
I'd appreciate your help.