theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

Fails when using scPoli for dataset training #205

Closed: melancholy12 closed this issue 1 year ago

melancholy12 commented 1 year ago

```python
scpoli_model.train(
    n_epochs=200,
    pretraining_epochs=40,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)
```

```
Initializing dataloaders
Starting training

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 1
----> 1 scpoli_model.train(n_epochs=200,
      2     pretraining_epochs=40,
      3     early_stopping_kwargs=early_stopping_kwargs,
      4     eta=5,
      5 )

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/models/scpoli/scpoli_model.py:304, in scPoli.train(self, n_epochs, pretraining_epochs, eta, lr, eps, alpha_epoch_anneal, reload_best, prototype_training, unlabeled_prototype_training, **kwargs)
    287     pretraining_epochs = int(np.floor(n_epochs * 0.9))
    290 self.trainer = scPoliTrainer(
    291     self.model,
    292     self.adata,
   (...)
    302     **kwargs,
    303 )
--> 304 self.trainer.train(n_epochs, lr, eps)
    305 self.is_trained_ = True
    306 self.prototypes_labeled_ = self.model.prototypes_labeled

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:305, in scPoliTrainer.train(self, n_epochs, lr, eps)
    302     batch_data[key] = batch.to(self.device)
    304 # loss calculation
--> 305 self.on_iteration(batch_data)
    307 # validation of model, monitoring, early stopping
    308 self.on_epoch_end()

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:333, in scPoliTrainer.on_iteration(self, batch_data)
    330     module.track_running_stats = False
    332 # calculate loss depending on trainer/model
--> 333 self.current_loss = loss = self.loss(batch_data)
    334 self.optimizer.zero_grad()
    336 loss.backward()

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:533, in scPoliTrainer.loss(self, total_batch)
    532 def loss(self, total_batch=None):
--> 533     latent, recon_loss, kl_loss, mmd_loss = self.model(**total_batch)
    535     # calculate classifier loss for labeled/unlabeled data
    536     label_categories = total_batch["labeled"].unique().tolist()

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:145, in scpoli.forward(self, x, batch, combined_batch, sizefactor, celltypes, labeled)
    143 else:
    144     z1_mean, z1_log_var = self.encoder(x_log, batch=None)
--> 145 z1 = self.sampling(z1_mean, z1_log_var)
    147 if "decoder" in self.inject_condition:
    148     outputs = self.decoder(z1, batch_embeddings)

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:331, in scpoli.sampling(self, mu, log_var)
    318 """Samples from standard Normal distribution and applies re-parametrization trick.
    319 It is actually sampling from latent space distributions with N(mu, var), computed by encoder.
    320 Parameters
   (...)
    328     Torch Tensor of sampled data.
    329 """
    330 var = torch.exp(log_var) + 1e-4
--> 331 return Normal(mu, var.sqrt()).rsample()

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
     54 else:
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File ~/anaconda3/envs/siyuan4/lib/python3.9/site-packages/torch/distributions/distribution.py:62, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     60 valid = constraint.check(value)
     61 if not valid.all():
---> 62     raise ValueError(
     63         f"Expected parameter {param} "
     64         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     65         f"of distribution {repr(self)} "
     66         f"to satisfy the constraint {repr(constraint)}, "
     67         f"but found invalid values:\n{value}"
     68     )
     69 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<...>)
```
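The `ValueError` itself only says that the encoder output used as the `loc` of the `Normal` is all NaN; the sampling step is where the problem surfaces, not where it originates. NaN or Inf entries in the input matrix are one common cause. A minimal pre-training sanity check, assuming `adata` is the AnnData object the model was set up with (the variable name is illustrative):

```python
import numpy as np
import scipy.sparse as sp

X = adata.X  # the matrix scPoli trains on
# For a sparse matrix, checking the stored values is enough;
# implicit zeros are always finite.
vals = X.data if sp.issparse(X) else np.asarray(X)

# NaN/Inf entries here propagate through the encoder and surface
# exactly as the Normal(loc, scale) constraint error above.
print("contains NaN:", bool(np.isnan(vals).any()))
print("contains Inf:", bool(np.isinf(vals).any()))
print("min/max:", vals.min(), vals.max())
```

If the matrix is clean, a diverging loss is the other usual suspect; since `scPoli.train` accepts an `lr` argument (visible in the signature in the traceback), lowering the learning rate is a reasonable first experiment.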

cdedonno commented 1 year ago

Could you provide some details about the dataset you are using, and the versions of the package?
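For reference, one quick way to collect the requested version information, assuming each package exposes the standard `__version__` attribute:

```python
import scarches
import torch
import anndata

print("scarches:", scarches.__version__)
print("torch:", torch.__version__)
print("anndata:", anndata.__version__)
```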

cdedonno commented 1 year ago

Closing due to no response.