theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

Runtime error using scPoli - mat1 and mat2 must have the same dtype, but got Double and Float #215

Closed · parkjosh-broadinstitute closed this issue 11 months ago

parkjosh-broadinstitute commented 11 months ago

Hi, I am trying to follow this tutorial: https://github.com/theislab/scarches/blob/master/notebooks/scpoli_surgery_pipeline.ipynb

```
matplotlib==3.8.0
numpy==1.26.0
pandas==2.1.1
scArches==0.5.9
scanpy==1.9.5
scikit-learn==1.3.1
seaborn==0.13.0
torch==2.1.0
```

When I try to train the model, I run into this error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 8
      1 scpoli_model = scPoli(
      2     adata=ref,
      3     condition_keys=condition_key,
   (...)
      6     recon_loss='zinb',
      7 )
----> 8 scpoli_model.train(
      9     n_epochs=125,
     10     pretraining_epochs=100,
     11     early_stopping_kwargs=early_stopping_kwargs,
     12     eta=5,
     13 )

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli_model.py:304, in scPoli.train(self, n_epochs, pretraining_epochs, eta, lr, eps, alpha_epoch_anneal, reload_best, prototype_training, unlabeled_prototype_training, **kwargs)
    287     pretraining_epochs = int(np.floor(n_epochs * 0.9))
    290 self.trainer = scPoliTrainer(
    291     self.model,
    292     self.adata,
   (...)
    302     **kwargs,
    303 )
--> 304 self.trainer.train(n_epochs, lr, eps)
    305 self.is_trained_ = True
    306 self.prototypes_labeled_ = self.model.prototypes_labeled

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:305, in scPoliTrainer.train(self, n_epochs, lr, eps)
    302     batch_data[key] = batch.to(self.device)
    304 #loss calculation
--> 305 self.on_iteration(batch_data)
    307 #validation of model, monitoring, early stopping
    308 self.on_epoch_end()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:333, in scPoliTrainer.on_iteration(self, batch_data)
    330     module.track_running_stats = False
    332 #calculate loss depending on trainer/model
--> 333 self.current_loss = loss = self.loss(batch_data)
    334 self.optimizer.zero_grad()
    336 loss.backward()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:533, in scPoliTrainer.loss(self, total_batch)
    532 def loss(self, total_batch=None):
--> 533     latent, recon_loss, kl_loss, mmd_loss = self.model(**total_batch)
    535     #calculate classifier loss for labeled/unlabeled data
    536     label_categories = total_batch["labeled"].unique().tolist()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:142, in scpoli.forward(self, x, batch, combined_batch, sizefactor, celltypes, labeled)
    140     x_log = x
    141 if "encoder" in self.inject_condition:
--> 142     z1_mean, z1_log_var = self.encoder(x_log, batch_embeddings)
    143 else:
    144     z1_mean, z1_log_var = self.encoder(x_log, batch=None)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:461, in Encoder.forward(self, x, batch)
    459     x = torch.cat((x, batch), dim=-1)
    460 if self.FC is not None:
--> 461     x = self.FC(x)
    462 means = self.mean_encoder(x)
    463 log_vars = self.log_var_encoder(x)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:665, in CondLayers.forward(self, x)
    663 else:
    664     expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)
--> 665     out = self.expr_L(expr) + self.cond_L(cond)
    666 return out

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float
```

Has anyone run into this error and been able to fix it?

cdedonno commented 11 months ago

Hi @parkjosh-broadinstitute, could you check whether casting the input to `float32` fixes the issue?

`adata.X = adata.X.astype('float32')` before passing it to the model.
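For context, a minimal sketch of the workaround (the file name is hypothetical; `.astype` works on both dense NumPy arrays and SciPy sparse matrices, which is what `adata.X` usually is):

```python
import scanpy as sc

adata = sc.read_h5ad('reference.h5ad')  # hypothetical input file
print(adata.X.dtype)  # float64 here is what triggers the Double/Float mismatch

# Cast the expression matrix to float32 before constructing/training the model.
adata.X = adata.X.astype('float32')
print(adata.X.dtype)  # float32
```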

hernandezvargash commented 11 months ago

Hi, I get the exact same `RuntimeError` when running the scPoli function. I tried your suggestion, `source_adata = source_adata.astype('float32')`, but then I get:

`AttributeError: 'AnnData' object has no attribute 'astype'`

My h5ad object was built from a Seurat object with the R function `SeuratDisk::Convert`. Thanks for your help!

cdedonno commented 11 months ago

You need to transform the `.X` matrix within your `adata` object.

The correct command is `source_adata.X = source_adata.X.astype('float32')`.
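For AnnData objects converted from Seurat, `.X` is typically a SciPy sparse matrix, and `.astype` is defined on that matrix rather than on the `AnnData` container, which is why `source_adata.astype(...)` raises the `AttributeError`. A quick check (a sketch, assuming `source_adata` was read from the converted `.h5ad`):

```python
import scipy.sparse as sp

# Inspect what SeuratDisk::Convert produced; sparse float64 is common.
print(type(source_adata.X), source_adata.X.dtype)
print(sp.issparse(source_adata.X))

# The cast goes on the matrix, not on the AnnData object.
source_adata.X = source_adata.X.astype('float32')
```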

hernandezvargash commented 11 months ago

Oops, my mistake. This works for me, and I managed to complete the rest of the analysis without error. Thanks!

parkjosh-broadinstitute commented 11 months ago

Hi @cdedonno, yes, that worked. Thank you for your help!

parkjosh-broadinstitute commented 11 months ago

Even though the "nb" and "zinb" losses expect raw counts, the data should still be cast to float32 using @cdedonno's suggestion:

`adata.X = adata.X.astype('float32')`
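A small sketch of why the cast is safe for count-based losses: it changes only the storage dtype, not the values (integers are represented exactly in float32 up to 2^24):

```python
import numpy as np

counts = np.random.poisson(2.0, size=(100, 50))  # raw integer counts (int64)
X = counts.astype('float32')                     # same values, float32 storage

assert np.array_equal(counts, X)  # the counts are preserved exactly
```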

cdedonno commented 11 months ago

Hi @parkjosh-broadinstitute, the transformation should work automatically, but apparently there can be issues when passing integers; I will check. Thanks for reporting this.
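Until that is handled upstream, a small defensive helper (the name `ensure_float32` is hypothetical) can be run on any AnnData before training:

```python
import numpy as np
from anndata import AnnData


def ensure_float32(adata: AnnData) -> AnnData:
    """Cast adata.X to float32 in place if it has any other dtype (dense or sparse)."""
    if adata.X.dtype != np.float32:
        adata.X = adata.X.astype(np.float32)
    return adata
```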