BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

ValueError: Expected parameter loc (Parameter of shape (23, 1)) of distribution Normal(loc: torch.Size([23, 1]), scale: torch.Size([23, 1])) to satisfy the constraint Real(), but found invalid values: #384

Open chuangzhao0601 opened 1 week ago

chuangzhao0601 commented 1 week ago

Hi! I ran the command mod.train(max_epochs=250, use_gpu=True) on the recommended dataset (V1_Human_Lymph_Node) on an M3 MacBook Pro and got this error:

/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/scvi/train/_trainrunner.py:76: UserWarning: `use_gpu` is deprecated in v1.0 and will be removed in v1.1. Please use `accelerator` and `devices` instead.
  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:69: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
<frozen abc>:119: FutureWarning: SparseDataset is deprecated and will be removed in late 2024. It has been replaced by the public classes CSRDataset and CSCDataset.

For instance checks, use `isinstance(X, (anndata.experimental.CSRDataset, anndata.experimental.CSCDataset))` instead.

For creation, use `anndata.experimental.sparse_dataset(X)` instead.

/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_struct.py:280: UserWarning: Encountered NaN: log_prob_sum at site 'detection_tech_gene_tg'
  warn_if_nan(
/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_struct.py:280: UserWarning: Encountered NaN: log_prob_sum at site 'data_target'
  warn_if_nan(
/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:158: UserWarning: Encountered NaN: loss
  warn_if_nan(loss, "loss")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.__call__(self, *args, **kwargs)
    190 try:
--> 191     ret = self.fn(*args, **kwargs)
    192 except (ValueError, RuntimeError) as e:

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/nn/module.py:520, in PyroModule.__call__(self, *args, **kwargs)
    519 with self._pyro_context:
--> 520     result = super().__call__(*args, **kwargs)
    521 if (
    522     pyro.settings.get("validate_poutine")
    523     and not self._pyro_context.active
    524     and _is_module_local_param_enabled()
    525 ):

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py:525, in AutoNormal.forward(self, *args, **kwargs)
    522 site_loc, site_scale = self._get_loc_and_scale(name)
    523 unconstrained_latent = pyro.sample(
    524     name + "_unconstrained",
--> 525     dist.Normal(
    526         site_loc,
    527         site_scale,
    528     ).to_event(self._event_dims[name]),
    529     infer={"is_auxiliary": True},
    530 )
    532 value = transform(unconstrained_latent)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.__call__(cls, *args, **kwargs)
     25         return result
---> 26 return super().__call__(*args, **kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/distributions/normal.py:57, in Normal.__init__(self, loc, scale, validate_args)
     56     batch_shape = self.loc.size()
---> 57 super().__init__(batch_shape, validate_args=validate_args)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/distributions/distribution.py:70, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     69         if not valid.all():
---> 70             raise ValueError(
     71                 f"Expected parameter {param} "
     72                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     73                 f"of distribution {repr(self)} "
     74                 f"to satisfy the constraint {repr(constraint)}, "
     75                 f"but found invalid values:\n{value}"
     76             )
     77 super().__init__()

ValueError: Expected parameter loc (Parameter of shape (23, 1)) of distribution Normal(loc: torch.Size([23, 1]), scale: torch.Size([23, 1])) to satisfy the constraint Real(), but found invalid values:
Parameter containing:
tensor([[nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan]], device='mps:0', requires_grad=True)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[10], line 3
      1 import torch
      2 if torch.backends.mps.is_available() and torch.backends.mps.is_built():
----> 3     mod.train(max_epochs=250, batch_size=5000, use_gpu=True)
      4 else:
      5     mod.train(max_epochs=250, batch_size=5000, use_gpu=False)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/cell2location/models/reference/_reference_model.py:157, in RegressionModel.train(self, max_epochs, batch_size, train_size, lr, **kwargs)
    154 kwargs["train_size"] = train_size
    155 kwargs["lr"] = lr
--> 157 super().train(**kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/scvi/model/base/_pyromixin.py:184, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    172 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
    174 runner = self._train_runner_cls(
    175     self,
    176     training_plan=training_plan,
   (...)
    182     **trainer_kwargs,
    183 )
--> 184 return runner()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41     if trainer.strategy.launcher is not None:
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43     return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
    567     ckpt_path,
    568     model_provided=True,
    569     model_connected=self.lightning_module is not None,
    570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
    573 assert self.state.stopped
    574 self.training = False

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:980, in Trainer._run(self, model, ckpt_path)
    975 self._signal_connector.register_signal_handlers()
    977 # ----------------------------
    978 # RUN THE TRAINER
    979 # ----------------------------
--> 980 results = self._run_stage()
    982 # ----------------------------
    983 # POST-Training CLEAN UP
    984 # ----------------------------
    985 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1021         self._run_sanity_check()
   1022     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1023         self.fit_loop.run()
   1024     return None
   1025 raise RuntimeError(f"Unexpected state {self.state}")

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:355, in _FitLoop.advance(self)
    353 self._data_fetcher.setup(combined_loader)
    354 with self.trainer.profiler.profile("run_training_epoch"):
--> 355     self.epoch_loop.run(self._data_fetcher)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:221, in _TrainingEpochLoop.advance(self, data_fetcher)
    219             batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
    220         else:
--> 221             batch_output = self.manual_optimization.run(kwargs)
    223 self.batch_progress.increment_processed()
    225 # update non-plateau LR schedulers
    226 # update epoch-interval ones only when we are at the end of training epoch

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py:91, in _ManualOptimization.run(self, kwargs)
     89 self.on_run_start()
     90 with suppress(StopIteration):  # no loop to break at this level
---> 91     self.advance(kwargs)
     92 self._restarting = False
     93 return self.on_run_end()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py:111, in _ManualOptimization.advance(self, kwargs)
    108 trainer = self.trainer
    110 # manually capture logged metrics
--> 111 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    112 del kwargs  # release the batch from memory
    113 self.trainer.strategy.post_training_step()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:294, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    291     return None
    293 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 294     output = fn(*args, **kwargs)
    296 # restore current_fx when nested context
    297 pl_module._current_fx_name = prev_fx_name

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:380, in Strategy.training_step(self, *args, **kwargs)
    378 with self.precision_plugin.train_step_context():
    379     assert isinstance(self.model, TrainingStep)
--> 380     return self.model.training_step(*args, **kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/scvi/train/_trainingplans.py:1041, in PyroTrainingPlan.training_step(self, batch, batch_idx)
   1039     kwargs.update({"kl_weight": self.kl_weight})
   1040 # pytorch lightning requires a Tensor object for loss
-> 1041 loss = torch.Tensor([self.svi.step(*args, **kwargs)])
   1043 _opt = self.optimizers()
   1044 _opt.step()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):
--> 237         yield self._get_trace(model, guide, args, kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/enum.py:60, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     58     model_trace, guide_trace = unwrapped_guide.get_traces()
     59 else:
---> 60     guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
     61         *args, **kwargs
     62     )
     63     if detach:
     64         guide_trace.detach_()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:216, in TraceHandler.get_trace(self, *args, **kwargs)
    208 def get_trace(self, *args, **kwargs) -> Trace:
    209     """
    210     :returns: data structure
    211     :rtype: pyro.poutine.Trace
   (...)
    214     Calls this poutine and returns its trace instead of the function's return value.
    215     """
--> 216     self(*args, **kwargs)
    217     return self.msngr.get_trace()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.__call__(self, *args, **kwargs)
    196         exc = exc_type("{}\n{}".format(exc_value, shapes))
    197         exc = exc.with_traceback(traceback)
--> 198         raise exc from e
    199     self.msngr.trace.add_node(
    200         "_RETURN", name="_RETURN", type="return", value=ret
    201     )
    202 return ret

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.__call__(self, *args, **kwargs)
    187 self.msngr.trace.add_node(
    188     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    189 )
    190 try:
--> 191     ret = self.fn(*args, **kwargs)
    192 except (ValueError, RuntimeError) as e:
    193     exc_type, exc_value, traceback = sys.exc_info()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/nn/module.py:520, in PyroModule.__call__(self, *args, **kwargs)
    518 def __call__(self, *args: Any, **kwargs: Any) -> Any:
    519     with self._pyro_context:
--> 520         result = super().__call__(*args, **kwargs)
    521     if (
    522         pyro.settings.get("validate_poutine")
    523         and not self._pyro_context.active
    524         and _is_module_local_param_enabled()
    525     ):
    526         self._check_module_local_param_usage()

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File [/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/nn/modules/module.py:1562](http://localhost:8888/opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1561), in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py:525, in AutoNormal.forward(self, *args, **kwargs)
    520         stack.enter_context(plates[frame.name])
    522 site_loc, site_scale = self._get_loc_and_scale(name)
    523 unconstrained_latent = pyro.sample(
    524     name + "_unconstrained",
--> 525     dist.Normal(
    526         site_loc,
    527         site_scale,
    528     ).to_event(self._event_dims[name]),
    529     infer={"is_auxiliary": True},
    530 )
    532 value = transform(unconstrained_latent)
    533 if poutine.get_mask() is False:

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.__call__(cls, *args, **kwargs)
     24     if result is not None:
     25         return result
---> 26 return super().__call__(*args, **kwargs)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/distributions/normal.py:57, in Normal.__init__(self, loc, scale, validate_args)
     55 else:
     56     batch_shape = self.loc.size()
---> 57 super().__init__(batch_shape, validate_args=validate_args)

File /opt/anaconda3/envs/cell2loc/lib/python3.11/site-packages/torch/distributions/distribution.py:70, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     68         valid = constraint.check(value)
     69         if not valid.all():
---> 70             raise ValueError(
     71                 f"Expected parameter {param} "
     72                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     73                 f"of distribution {repr(self)} "
     74                 f"to satisfy the constraint {repr(constraint)}, "
     75                 f"but found invalid values:\n{value}"
     76             )
     77 super().__init__()

ValueError: Expected parameter loc (Parameter of shape (23, 1)) of distribution Normal(loc: torch.Size([23, 1]), scale: torch.Size([23, 1])) to satisfy the constraint Real(), but found invalid values:
Parameter containing:
tensor([[nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan]], device='mps:0', requires_grad=True)
                             Trace Shapes:                      
                              Param Sites:                      
         AutoNormal.locs.per_cluster_mu_fg     34 10237         
       AutoNormal.scales.per_cluster_mu_fg     34 10237         
    AutoNormal.locs.detection_tech_gene_tg      2 10237         
  AutoNormal.scales.detection_tech_gene_tg      2 10237         
        AutoNormal.locs.detection_mean_y_e     23     1         
      AutoNormal.scales.detection_mean_y_e     23     1         
    AutoNormal.locs.s_g_gene_add_alpha_hyp      1     1         
  AutoNormal.scales.s_g_gene_add_alpha_hyp      1     1         
         AutoNormal.locs.s_g_gene_add_mean     23     1         
       AutoNormal.scales.s_g_gene_add_mean     23     1         
  AutoNormal.locs.s_g_gene_add_alpha_e_inv     23     1         
AutoNormal.scales.s_g_gene_add_alpha_e_inv     23     1         
                             Sample Sites:                      
                            obs_plate dist            |         
                                     value   5000     |         
      per_cluster_mu_fg_unconstrained dist            | 34 10237
                                     value            | 34 10237
                    per_cluster_mu_fg dist            | 34 10237
                                     value            | 34 10237
 detection_tech_gene_tg_unconstrained dist            |  2 10237
                                     value            |  2 10237
               detection_tech_gene_tg dist            |  2 10237
                                     value            |  2 10237
     detection_mean_y_e_unconstrained dist            | 23     1
                                     value            | 23     1
                   detection_mean_y_e dist            | 23     1
                                     value            | 23     1
 s_g_gene_add_alpha_hyp_unconstrained dist 1    1     |         
                                     value 1    1     |         
               s_g_gene_add_alpha_hyp dist 1    1     |         
                                     value 1    1     |         
      s_g_gene_add_mean_unconstrained dist            | 23     1
                                     value            | 23     1
                    s_g_gene_add_mean dist            | 23     1
                                     value            | 23     1

Relevant package versions: cell2location 0.1.4, torch 2.4.1, pytorch-lightning 2.4.0, scvi-tools 1.0.4
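
For anyone reproducing this, a version list like the one above can be collected with the standard library's importlib.metadata. This is only a minimal sketch: the package list is an assumption (the PyPI distribution names of the four packages listed, plus pyro-ppl and anndata, which also appear in the traceback), so adjust it to your environment.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names; edit this list to match your environment.
PACKAGES = [
    "cell2location",
    "torch",
    "pytorch-lightning",
    "scvi-tools",
    "pyro-ppl",
    "anndata",
]


def collect_versions(packages=PACKAGES):
    """Return {package name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is absent from the active environment.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting the printed lines into a report keeps the environment description unambiguous.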

vitkl commented 1 day ago

Does this error happen immediately or after some time? If immediately, I would check that the provided data is not normalised. If it happens after some time, or the data is indeed not normalised, then it could be a numerical accuracy or package version issue on the M3 Mac. I don't know which environment recipe works best with the GPU capability of M1/M2/M3 Macs.
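
The first check suggested here (that the input is not normalised) can be approximated with a small heuristic: raw counts should be finite, non-negative, and integer-valued. The sketch below is an illustration under stated assumptions, not part of cell2location; the function name looks_like_raw_counts and the tolerance are hypothetical, and for sparse input it assumes a scipy CSR/CSC/COO matrix whose stored entries live in `.data`.

```python
import numpy as np


def looks_like_raw_counts(X, tol=1e-6):
    """Heuristic check that X holds finite, non-negative, integer-valued counts."""
    if hasattr(X, "toarray"):
        # Sparse matrix (assumed CSR/CSC/COO): implicit zeros are valid counts,
        # so only the explicitly stored entries need checking.
        values = np.asarray(X.data)
    else:
        values = np.asarray(X)
    values = values.astype(np.float64, copy=False).ravel()

    if values.size == 0:
        return True
    if not np.all(np.isfinite(values)):
        return False  # NaN or inf in the input would propagate into the loss
    if np.any(values < 0):
        return False  # scaled or centred data typically contains negatives
    # Log-transformed or normalised data typically has fractional values.
    return bool(np.all(np.abs(values - np.round(values)) <= tol))
```

With an AnnData object, this would be called as looks_like_raw_counts(adata_ref.X), where adata_ref is a hypothetical name for the reference dataset. If the data passes and the NaNs appear only on the MPS device, one thing to try is CPU training via mod.train(max_epochs=250, batch_size=5000, accelerator="cpu"). The deprecation warning in the log recommends `accelerator` over `use_gpu`, and the traceback shows the keyword being forwarded to scvi-tools, but whether this avoids the error on M-series Macs is untested here.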