theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

scPoli dtype error #196

Closed JesperGrud closed 1 year ago

JesperGrud commented 1 year ago

Hi,

Thanks for creating an interesting tool. I'm trying to apply scPoli to a large dataset with a lot of batches, running the following:

scpoli_model = scPoli(
    adata=adata_hvg,
    condition_keys='batch_indices',
    cell_type_keys='label',
    embedding_dims=5,
    recon_loss='nb',
)
scpoli_model.train(
    n_epochs=50,
    pretraining_epochs=40,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)

This results in the following output and subsequent error message:

Embedding dictionary:
    Num conditions: [248]
    Embedding dim: [5]
Encoder Architecture:
    Input Layer in, out and cond: 4000 64 5
    Mean/Var Layer in/out: 64 10
Decoder Architecture:
    First Layer in, out and cond: 10 64 5
    Output Layer in/out: 64 4000

Initializing dataloaders
Starting training

RuntimeError                              Traceback (most recent call last)
Cell In[25], line 8
      1 scpoli_model = scPoli(
      2     adata=adata_hvg,
      3     condition_keys='batch_indices',
    (...)
      6     recon_loss='nb',
      7 )
----> 8 scpoli_model.train(
      9     n_epochs=50,
     10     pretraining_epochs=40,
     11     early_stopping_kwargs=early_stopping_kwargs,
     12     eta=5,
     13 )

File /opt/conda/lib/python3.10/site-packages/scarches/models/scpoli/scpoli_model.py:304, in scPoli.train(self, n_epochs, pretraining_epochs, eta, lr, eps, alpha_epoch_anneal, reload_best, prototype_training, unlabeled_prototype_training, **kwargs)
    287     pretraining_epochs = int(np.floor(n_epochs * 0.9))
    290 self.trainer = scPoliTrainer(
    291     self.model,
    292     self.adata,
    (...)
    302     **kwargs,
    303 )
--> 304 self.trainer.train(n_epochs, lr, eps)
    305 self.is_trained_ = True
    306 self.prototypes_labeled_ = self.model.prototypes_labeled

File /opt/conda/lib/python3.10/site-packages/scarches/trainers/scpoli/trainer.py:305, in scPoliTrainer.train(self, n_epochs, lr, eps)
    302     batch_data[key] = batch.to(self.device)
    304 #loss calculation
--> 305 self.on_iteration(batch_data)
    307 #validation of model, monitoring, early stopping
    308 self.on_epoch_end()

File /opt/conda/lib/python3.10/site-packages/scarches/trainers/scpoli/trainer.py:333, in scPoliTrainer.on_iteration(self, batch_data)
    330 module.track_running_stats = False
    332 #calculate loss depending on trainer/model
--> 333 self.current_loss = loss = self.loss(batch_data)
    334 self.optimizer.zero_grad()
    336 loss.backward()

File /opt/conda/lib/python3.10/site-packages/scarches/trainers/scpoli/trainer.py:533, in scPoliTrainer.loss(self, total_batch)
    532 def loss(self, total_batch=None):
--> 533     latent, recon_loss, kl_loss, mmd_loss = self.model(**total_batch)
    535     #calculate classifier loss for labeled/unlabeled data
    536     label_categories = total_batch["labeled"].unique().tolist()

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/scarches/models/scpoli/scpoli.py:142, in scpoli.forward(self, x, batch, combined_batch, sizefactor, celltypes, labeled)
    140 x_log = x
    141 if "encoder" in self.inject_condition:
--> 142     z1_mean, z1_log_var = self.encoder(x_log, batch_embeddings)
    143 else:
    144     z1_mean, z1_log_var = self.encoder(x_log, batch=None)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/scarches/models/scpoli/scpoli.py:461, in Encoder.forward(self, x, batch)
    459     x = torch.cat((x, batch), dim=-1)
    460 if self.FC is not None:
--> 461     x = self.FC(x)
    462 means = self.mean_encoder(x)
    463 log_vars = self.log_var_encoder(x)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/scarches/models/scpoli/scpoli.py:665, in CondLayers.forward(self, x)
    663 else:
    664     expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)
--> 665     out = self.expr_L(expr) + self.cond_L(cond)
    666 return out

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 must have the same dtype
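
(For reference: PyTorch raises this error whenever the input to a linear layer and the layer's weight matrix have different floating-point dtypes. A minimal sketch, independent of scPoli, that should reproduce the same message:)

import torch

layer = torch.nn.Linear(4000, 64)                # weights are float32 by default
x = torch.randn(8, 4000, dtype=torch.float64)    # float64 input, e.g. coming from a float64 adata.X
layer(x)                                         # RuntimeError: mat1 and mat2 must have the same dtype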

As a 'control' I tried to plug the same dataset into scVI within scArches using the following:

sca.models.SCVI.setup_anndata(adata_hvg, batch_key='batch_indices', labels_key='label')
vae = sca.models.SCVI(
    adata_hvg,
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)
vae.train()

This runs smoothly. In adata_hvg.X I have raw counts, and I have tried both a sparse and a non-sparse matrix. The adata_hvg object looks like this:

AnnData object with n_obs × n_vars = 323500 × 4000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'mitochondrial_fraction', 'ribosomal_fraction', 'coding_fraction', 'label', 'Gender', 'Cov1', 'Cov2', 'Cov3', 'Cov4', 'Cov5', 'Cov6', 'Cov7', 'Source', 'batch', 'batch_indices'
    var: 'features', 'means', 'variances', 'residual_variances', 'highly_variable_rank', 'highly_variable_nbatches', 'highly_variable_intersection', 'highly_variable'
    uns: 'hvg'
    layers: 'raw'

Any ideas?

Thanks in advance

cdedonno commented 1 year ago

Hey @JesperGrud, the code looks correct to me. I am wondering: are your batch_indices stored with a numerical data type in your adata.obs? If that's the case, could you check if casting them to a categorical type using adata_hvg.obs['batch_indices'] = adata_hvg.obs['batch_indices'].astype('category') fixes this bug?

Alternatively, if your batch column in .obs is a categorical equivalent of batch_indices, try passing that as condition_keys to the model instead, as sketched below.
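
In code, the two suggestions would look roughly like this (a sketch reusing the scPoli call from above; it assumes the 'batch' column holds the categorical batch labels):

# option 1: make sure the condition key is categorical, not numeric
adata_hvg.obs['batch_indices'] = adata_hvg.obs['batch_indices'].astype('category')

# option 2: pass the categorical 'batch' column as the condition key instead
scpoli_model = scPoli(
    adata=adata_hvg,
    condition_keys='batch',
    cell_type_keys='label',
    embedding_dims=5,
    recon_loss='nb',
)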

JesperGrud commented 1 year ago

Thank you for the quick reply. The 'batch_indices' column is already categorical. Here it is:

cell_641        1
cell_642        1
cell_643        1
cell_644        1
cell_645        1
             ...
cell733069    151
cell733484    151
cell733812    151
cell734019    151
cell734576    151
Name: batch_indices, Length: 323500, dtype: category
Categories (228, int64): [0, 1, 2, 3, ..., 224, 225, 226, 227]

Nonetheless, I did try the following:

adata_hvg.obs['label'] = adata_hvg.obs['label'].astype('category')
adata_hvg.obs['batch'] = adata_hvg.obs['batch'].astype('category')

I do get the same error after doing that.

cdedonno commented 1 year ago

Hm, strange. Can you show me the output of adata_hvg.X.dtype? I suspect it might be different from the standard float32. In that case, you could try casting your input using adata_hvg.X = adata_hvg.X.astype('float32').
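
For reference, the check-and-cast would look something like this (a sketch; .astype('float32') works the same way whether adata_hvg.X is a dense numpy array or a scipy sparse matrix):

print(adata_hvg.X.dtype)  # if this prints float64, it will clash with the model's float32 layers

# cast the expression matrix to float32 before building/training the model
adata_hvg.X = adata_hvg.X.astype('float32')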

JesperGrud commented 1 year ago

You're correct, it was float64. Thanks a lot. Casting to float32 solved the problem, so I'm closing the issue.

Thanks again!