scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

total_vi.train() ValueError: Expected parameter loc (Tensor of shape (256, 20)) of distribution Normal(loc: torch.Size([256, 20]), scale: torch.Size([256, 20])) to satisfy the constraint Real() #2981

Open raozuming opened 23 hours ago

raozuming commented 23 hours ago

import scvi

scvi.model.TOTALVI.setup_mudata(
    mdata,
    rna_layer="counts" if rna_use_raw else None,
    protein_layer="counts" if protein_use_raw else None,
    modalities={
        "rna_layer": "multiomics" if self._use_hvg else "rna",
        "protein_layer": "protein",
    },
)

total_vi = scvi.model.TOTALVI(mdata, **kwargs)
total_vi.train()

Traceback (most recent call last):
  File "../../saw_multianalysis/multiAnalysis.py", line 410, in <module>
    run()
  File "../../saw_multianalysis/multiAnalysis.py", line 405, in run
    main(rna_path, protein_path, bin_size, protein_list, out_dir, convert_py_bool(use_gpu), gpu, num_threads,
  File "../../saw_multianalysis/multiAnalysis.py", line 309, in main
    totalVI(rna_data, protein_data, prefix, proteins, out_dir, use_gpu, gpu, num_threads, report)
  File "../../saw_multianalysis/multiAnalysis.py", line 176, in totalVI
    total_vi = ms_data.tl.total_vi(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/stereo/algorithm/total_vi.py", line 142, in main
    total_vi.train(use_gpu=use_gpu, **train_kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/model/_totalvi.py", line 299, in train
    return runner()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/train/_trainrunner.py", line 82, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/train/_trainer.py", line 188, in fit
    super().fit(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/core/module.py", line 1705, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/optim/optimizer.py", line 391, in wrapper
    out = func(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/optim/adam.py", line 148, in step
    loss = closure()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
    closure_result = closure()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
    step_output = self._step_fn()
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
    return self.model.training_step(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/train/_trainingplans.py", line 559, in training_step
    inference_outputs, _, scvi_loss = self.forward(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/train/_trainingplans.py", line 282, in forward
    return self.module(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
    return fn(self, *args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/module/base/_base_module.py", line 276, in forward
    return _generic_forward(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/module/base/_base_module.py", line 837, in _generic_forward
    inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
    return fn(self, *args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/module/_totalvae.py", line 491, in inference
    qz, ql, latent, untran_latent = self.encoder(
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/scvi/nn/_base_components.py", line 1029, in forward
    q_z = Normal(qz_m, qz_v.sqrt())
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File ".conda/envs/multiAnalysis/lib/python3.8/site-packages/torch/distributions/distribution.py", line 80, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (256, 20)) of distribution Normal(loc: torch.Size([256, 20]), scale: torch.Size([256, 20])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)
Epoch 1/400:   0%|   
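The `ValueError` means the encoder produced NaN for the mean (`loc`) of the latent Normal distribution during the first epoch. A reasonable first check is to rule out problems in the input itself: NaN/inf entries, negative or non-integer values (these models are meant to be trained on raw counts), and cells with zero or very few counts. This is a minimal NumPy-only sketch; the function name `summarize_count_matrix` and the `min_counts_per_cell=10` threshold are hypothetical choices for illustration, not scvi-tools API.

```python
import numpy as np


def summarize_count_matrix(X, min_counts_per_cell=10):
    """Count common input problems in a cells x features matrix.

    Hypothetical diagnostic helper (not part of scvi-tools). Accepts a dense
    2-D array-like or a SciPy sparse matrix; min_counts_per_cell is an
    arbitrary example threshold.
    """
    if hasattr(X, "tocsr"):
        # Sparse input: inspect only the explicitly stored values.
        X = X.tocsr()
        values = np.asarray(X.data, dtype=np.float64)
        totals = np.asarray(X.sum(axis=1), dtype=np.float64).ravel()
    else:
        X = np.asarray(X, dtype=np.float64)
        values = X.ravel()
        totals = X.sum(axis=1)

    finite = np.isfinite(values)
    finite_values = values[finite]
    return {
        "n_nonfinite": int((~finite).sum()),
        "n_negative": int((finite_values < 0).sum()),
        "n_non_integer": int((finite_values != np.round(finite_values)).sum()),
        "n_empty_cells": int((totals == 0).sum()),
        # Includes empty cells; rows whose total is NaN are not counted here.
        "n_low_count_cells": int((totals < min_counts_per_cell).sum()),
    }
```

For example, `summarize_count_matrix(mdata["rna"].layers["counts"])`, assuming the `"rna"` modality has a `"counts"` layer as in the `setup_mudata` call above. Non-zero values for `n_nonfinite`, `n_negative` or `n_non_integer` point to an input problem rather than training instability.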

Versions:

Python 3.8, scvi-tools 0.19.0

canergen commented 22 hours ago

Please try our recent changes by installing scvi-tools from the main branch; I hope this is fixed there. These changes will be released with scvi-tools 1.2.

raozuming commented 22 hours ago

@canergen Thank you for your reply. Since scvi-tools 1.2 requires Python > 3.8 and I cannot upgrade my Python version, I cannot use it. Second, I tried manually applying the fix from your pull request (https://github.com/scverse/scvi-tools/pull/2632/files) to _trainingplans.py, and the problem still occurs.
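Whether a given environment can run the newer releases comes down to the interpreter version. This is a small stdlib-only sketch for checking it; `is_supported_python` is a hypothetical helper, and the bounds (3.10 to 3.12) are hard-coded from the maintainer's reply later in this thread, so they are an assumption that will drift as new releases come out.

```python
import sys

# Supported range as stated by the maintainers in this thread (assumption:
# check the scvi-tools install docs for the current range).
SUPPORTED_MIN = (3, 10)
SUPPORTED_MAX = (3, 12)


def is_supported_python(version=None):
    """Return True if the (major, minor) version lies in the supported range."""
    if version is None:
        version = sys.version_info[:2]
    major_minor = tuple(version[:2])
    return SUPPORTED_MIN <= major_minor <= SUPPORTED_MAX


if not is_supported_python():
    major, minor = sys.version_info[:2]
    print(f"Python {major}.{minor} is outside {SUPPORTED_MIN} to {SUPPORTED_MAX}; "
          "create a new environment before installing scvi-tools 1.2.")
```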

mt1022 commented 16 hours ago

I am using Python 3.12 and have the same problem when training a scvi.model.SCVI model. I tried installing the main branch and the 1.2.x branch, and got the same error with both.

cmd:

model.train(early_stopping=True, accelerator='cpu')    # same error with "mps"

error:

File ~/miniforge3/envs/sc/lib/python3.12/site-packages/torch/distributions/distribution.py:70, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     68         valid = constraint.check(value)
     69         if not valid.all():
---> 70             raise ValueError(
     71                 f"Expected parameter {param} "
     72                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     73                 f"of distribution {repr(self)} "
     74                 f"to satisfy the constraint {repr(constraint)}, "
     75                 f"but found invalid values:\n{value}"
     76             )
     77 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

session info

-----
anndata     0.10.8
scanpy      1.10.3
-----
CoreFoundation              NA
Foundation                  NA
PIL                         10.4.0
PyObjCTools                 NA
absl                        NA
anyio                       NA
appnope                     0.1.4
arrow                       1.3.0
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
babel                       2.14.0
brotli                      1.1.0
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
chex                        0.1.86
colorama                    0.4.6
comm                        0.2.2
contextlib2                 NA
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0
debugpy                     1.8.5
decorator                   5.1.1
defusedxml                  0.7.1
docrep                      0.3.2
etils                       1.9.4
executing                   2.1.0
fastjsonschema              NA
filelock                    3.16.1
flax                        0.9.0
fqdn                        NA
fsspec                      2024.9.0
gmpy2                       2.1.5
google                      NA
h5py                        3.11.0
idna                        3.10
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jax                         0.4.31
jaxlib                      0.4.31
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
legacy_api_wrap             NA
lightning                   2.4.0
lightning_utilities         0.11.7
llvmlite                    0.43.0
markupsafe                  2.1.5
matplotlib                  3.9.2
ml_collections              NA
ml_dtypes                   0.5.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
mudata                      0.3.1
multipledispatch            0.6.0
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numpy                       1.26.4
numpyro                     0.15.3
objc                        10.3.1
opt_einsum                  v3.3.0
optax                       0.2.2
overrides                   NA
packaging                   24.1
pandas                      2.2.2
parso                       0.8.4
patsy                       0.5.6
pickleshare                 0.7.5
platformdirs                4.3.6
prometheus_client           NA
prompt_toolkit              3.0.47
psutil                      6.0.0
pure_eval                   0.2.3
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      2.9.5
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pyro                        1.9.1+0a67ddc
pythonjsonlogger            NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.14.1
scvi                        1.1.6
send2trash                  NA
session_info                1.0.0
six                         1.16.0
sklearn                     1.5.2
sniffio                     1.3.1
socks                       1.7.1
sparse                      0.15.4
stack_data                  0.6.2
statsmodels                 0.14.3
sympy                       1.13.2
threadpoolctl               3.5.0
toolz                       0.12.1
torch                       2.4.0
torchgen                    NA
torchmetrics                1.4.2
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     2.2.3
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
xarray                      2024.9.0
yaml                        6.0.2
zmq                         26.2.0
zstandard                   0.23.0
-----
IPython             8.27.0
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.12.6 | packaged by conda-forge | (main, Sep 11 2024, 04:55:15) [Clang 17.0.6 ]
macOS-14.6.1-arm64-arm-64bit
-----
Session information updated at 2024-09-19 22:09

canergen commented 16 hours ago

@raozuming We don’t support outdated Python versions. Older releases had an unfortunate stacking of activation functions, which makes some training runs unstable. It is safe to retry training several times (e.g., around 10 runs) until one succeeds. If the issue persists, please set up a new environment with a supported Python version (3.10 to 3.12).

@mt1022 Yours is a different issue, since it involves a different model. Please check out https://docs.scvi-tools.org/en/latest/faq.html. The error is usually caused by cells with very low counts, which produce bad gradients. Please share your dataset size and your full code, and verify that your AnnData contains raw counts.
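For the TOTALVI case, the advice to retry several times can be wrapped in a small helper that retrains with a fresh seed until one run completes. This is a minimal sketch, not scvi-tools API: `train_with_retries` is hypothetical, and it assumes a failed run surfaces as the `ValueError` that `torch.distributions` raises for NaN parameters, as in the tracebacks above. The commented usage assumes `scvi.settings.seed` for reseeding and an `mdata` already registered with `TOTALVI.setup_mudata`.

```python
def train_with_retries(train_fn, max_attempts=10, base_seed=0):
    """Call train_fn(seed) with a new seed until it succeeds.

    Only ValueError triggers a retry (the error torch.distributions raises
    when a parameter such as `loc` contains NaN); any other exception
    propagates immediately. Returns (seed, result) of the first success.
    """
    last_error = None
    for attempt in range(max_attempts):
        seed = base_seed + attempt
        try:
            return seed, train_fn(seed)
        except ValueError as err:
            last_error = err
            first_line = str(err).split("\n", 1)[0]
            print(f"attempt {attempt + 1} (seed={seed}) failed: {first_line}")
    raise RuntimeError(f"all {max_attempts} attempts failed") from last_error


# Hypothetical usage with scvi-tools (not executed here):
# def build_and_train(seed):
#     scvi.settings.seed = seed          # reseed before building the model
#     model = scvi.model.TOTALVI(mdata)  # mdata already set up via setup_mudata
#     model.train()
#     return model
#
# seed, model = train_with_retries(build_and_train)
```

Catching only `ValueError` keeps genuine bugs (wrong keys, missing layers) from being silently retried ten times.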