scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

ValueError during vae.train() #1281

Closed · 321356766 closed this issue 2 years ago

321356766 commented 2 years ago

I am receiving an error from vae.train() when using totalVI. The same data runs fine with scVI on the gene expression (GEX) modality alone. Setting latent_distribution="ln" during TOTALVI model setup seems to bypass the error, but the resulting CITE-seq/scRNA-seq results do not look good.

Any insight into what could be causing this error would be much appreciated. Similar datasets (different cell donors) run fine.
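For reference, the setup looks roughly like this (a minimal sketch with placeholder keys, using the scvi-tools 0.14.x-era registration API; newer releases use `TOTALVI.setup_anndata` instead):

```python
import scvi

# Register the AnnData object (placeholder obsm/obs keys)
scvi.data.setup_anndata(
    adata,
    protein_expression_obsm_key="protein_expression",
    batch_key="batch",
)

vae = scvi.model.TOTALVI(adata)  # default latent_distribution="normal" hits the error below
# vae = scvi.model.TOTALVI(adata, latent_distribution="ln")  # bypasses the error, poor results
vae.train()
```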

Thanks


```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 10/400:   2%|▏         | 9/400 [00:30<21:43,  3.33s/it, loss=1.13e+03, v_num=1]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_21647/2537747719.py in <module>
----> 1 vae.train()

~/anaconda3/lib/python3.8/site-packages/scvi/model/_totalvi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
    284             **kwargs,
    285         )
--> 286         return runner()
    287 
    288     @torch.no_grad()

~/anaconda3/lib/python3.8/site-packages/scvi/train/_trainrunner.py in __call__(self)
     70             self.training_plan.n_obs_training = self.data_splitter.n_train
     71 
---> 72         self.trainer.fit(self.training_plan, self.data_splitter)
     73         self._update_history()
     74 

~/anaconda3/lib/python3.8/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    175                     message="`LightningModule.configure_optimizers` returned `None`",
    176                 )
--> 177             super().fit(*args, **kwargs)

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    869                 with self.profiler.profile("run_training_epoch"):
    870                     # run train epoch
--> 871                     self.train_loop.run_training_epoch()
    872 
    873                 if self.max_steps and self.max_steps <= self.global_step:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    497             # ------------------------------------
    498             with self.trainer.profiler.profile("run_training_batch"):
--> 499                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    500 
    501             # when returning -1 from train_step, we end epoch early

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    736 
    737                         # optimizer step
--> 738                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    739                         if len(self.trainer.optimizers) > 1:
    740                             # revert back to previous state

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    432 
    433         # model hook
--> 434         model_ref.optimizer_step(
    435             self.trainer.current_epoch,
    436             batch_idx,

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1401 
   1402         """
-> 1403         optimizer.step(closure=optimizer_closure)
   1404 
   1405     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    327         )
    328         if make_optimizer_step:
--> 329             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    330         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    331         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    334         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    335     ) -> None:
--> 336         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    337 
    338     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    191 
    192     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193         optimizer.step(closure=lambda_closure, **kwargs)
    194 
    195     @property

~/anaconda3/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     86                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87                 with torch.autograd.profiler.record_function(profile_name):
---> 88                     return func(*args, **kwargs)
     89             return wrapper
     90 

~/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

~/anaconda3/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
     90         if closure is not None:
     91             with torch.enable_grad():
---> 92                 loss = closure()
     93 
     94         for group in self.param_groups:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    730 
    731                         def train_step_and_backward_closure():
--> 732                             result = self.training_step_and_backward(
    733                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    734                             )

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    821         with self.trainer.profiler.profile("training_step_and_backward"):
    822             # lightning module hook
--> 823             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    824             self._curr_step_result = result
    825 

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    288             model_ref._results = Result()
    289             with self.trainer.profiler.profile("training_step"):
--> 290                 training_step_output = self.trainer.accelerator.training_step(args)
    291                 self.trainer.accelerator.post_training_step()
    292 

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    202 
    203         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204             return self.training_type_plugin.training_step(*args)
    205 
    206     def post_training_step(self) -> None:

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    153 
    154     def training_step(self, *args, **kwargs):
--> 155         return self.lightning_module.training_step(*args, **kwargs)
    156 
    157     def post_training_step(self):

~/anaconda3/lib/python3.8/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    361         if optimizer_idx == 0:
    362             loss_kwargs = dict(kl_weight=self.kl_weight)
--> 363             inference_outputs, _, scvi_loss = self.forward(
    364                 batch, loss_kwargs=loss_kwargs
    365             )

~/anaconda3/lib/python3.8/site-packages/scvi/train/_trainingplans.py in forward(self, *args, **kwargs)
    145     def forward(self, *args, **kwargs):
    146         """Passthrough to `model.forward()`."""
--> 147         return self.module(*args, **kwargs)
    148 
    149     def training_step(self, batch, batch_idx, optimizer_idx=0):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/anaconda3/lib/python3.8/site-packages/scvi/module/base/_base_module.py in forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    143             tensors, **get_inference_input_kwargs
    144         )
--> 145         inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
    146         generative_inputs = self._get_generative_input(
    147             tensors, inference_outputs, **get_generative_input_kwargs

~/anaconda3/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/anaconda3/lib/python3.8/site-packages/scvi/module/_totalvae.py in inference(self, x, y, batch_index, label, n_samples, cont_covs, cat_covs)
    465         else:
    466             categorical_input = tuple()
--> 467         qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
    468             encoder_input, batch_index, *categorical_input
    469         )

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/scvi/nn/_base_components.py in forward(self, data, *cat_list)
    992         qz_m = self.z_mean_encoder(q)
    993         qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 994         z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
    995 
    996         ql_gene = self.l_gene_encoder(data, *cat_list)

~/anaconda3/lib/python3.8/site-packages/scvi/nn/_base_components.py in reparameterize_transformation(self, mu, var)
    958 
    959     def reparameterize_transformation(self, mu, var):
--> 960         untran_z = Normal(mu, var.sqrt()).rsample()
    961         z = self.z_transformation(untran_z)
    962         return z, untran_z

~/anaconda3/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

~/anaconda3/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     53                 valid = constraint.check(value)
     54                 if not valid.all():
---> 55                     raise ValueError(
     56                         f"Expected parameter {param} "
     57                         f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected parameter loc (Tensor of shape (256, 20)) of distribution Normal(loc: torch.Size([256, 20]), scale: torch.Size([256, 20])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
```

#### Versions:
> 0.14.5
adamgayoso commented 2 years ago

Can you provide more details on the dataset that fails? How many genes/cells? Is this an integration of scRNA-seq and CITE-seq data such that you're adding "missing" protein measurements (represented as 0s) for the scRNA-seq cells?

adamgayoso commented 2 years ago

Here are some things you can try in the meantime (a code sketch follows the list):

  1. Set `empirical_protein_background_prior=False` when initializing TOTALVI; this prior can be misspecified if there are missing proteins.
  2. Turn down the learning rate during training, e.g. `vae.train(lr=2e-3)`.
  3. If using `latent_distribution="ln"`, use `metric="hellinger"` or `metric="correlation"` for the neighbors graph to get a better visualization.
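As a rough sketch of how these suggestions translate to code (assuming `adata` is the registered AnnData object and scanpy is used for the neighbors graph; argument names per scvi-tools 0.14.x):

```python
import scanpy as sc
import scvi

# 1. Disable the empirical protein background prior, which can be
#    misspecified when some proteins are missing or near-zero
vae = scvi.model.TOTALVI(adata, empirical_protein_background_prior=False)

# 2. Train with a lower learning rate
vae.train(lr=2e-3)

# 3. With latent_distribution="ln", build the neighbors graph with a metric
#    better suited to the logistic-normal latent space
adata.obsm["X_totalVI"] = vae.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_totalVI", metric="hellinger")
```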
321356766 commented 2 years ago

Thanks for your response. The dataset is around 50,000 cells, and I am trying to create a joint representation of 4,000 HVGs with 150+ ADTs. Interestingly, larger datasets than this work flawlessly. It is one coherent dataset; I am not trying to integrate two datasets, one of which has missing measurements. I will try your suggestions and report back. Many thanks!

adamgayoso commented 2 years ago

I see. You may also want to check whether there are any (cell, protein) values in the protein matrix that are extremely large. This can be an issue with CITE-seq.
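
A quick way to run that check (the keys below are placeholders; point them at wherever your RNA counts and ADT matrix live):

```python
import numpy as np
import scipy.sparse as sp

rna = adata.X                                        # placeholder: RNA counts
adt = np.asarray(adata.obsm["protein_expression"])   # placeholder: ADT counts

rna_vals = rna.data if sp.issparse(rna) else np.asarray(rna)

# Look for NaNs, negative entries, and extreme (cell, protein) values
print("RNA -> NaN:", np.isnan(rna_vals).any(), "| negative:", (rna_vals < 0).any())
print("ADT -> NaN:", np.isnan(adt).any(), "| negative:", (adt < 0).any())
print("ADT max per protein:")
print(adt.max(axis=0))
```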

adamgayoso commented 2 years ago

@321356766 did you make any progress on this?

adamgayoso commented 2 years ago

Closing due to inactivity.