Hi there!

First of all, I'd like to thank you for this useful tool; I think its ability to impute protein expression from diverse reference datasets is very relevant. However, I have not been able to get it working myself. After setting up the sciPENN object, I get the error attached below at the sciPENN.train() step.
To add further context, I am using three publicly available CITE-seq datasets as reference and an in-house scRNA-seq AnnData (gene expression only) as query. Although I have not had time to dig deeper into it, I did try different batch_size values in the sciPENN_API() call, but that did not seem to solve the issue.
Do you have any ideas as to what might be causing it?
Thank you again for your help!
P.S. I performed per-cell count normalization (cell_normalize), log1p transformation (log_normalize), and highly variable gene selection (select_hvg) beforehand, so I set those parameters to False.
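In case it is useful, here is a minimal sketch of that preprocessing using Scanpy; the target_sum and n_top_genes values are placeholders, not my exact settings, and adata stands for each reference and query AnnData in turn:

```python
import scanpy as sc

# Per-cell count normalization (the step cell_normalize would otherwise do)
sc.pp.normalize_total(adata, target_sum=1e4)

# log1p transformation (the step log_normalize would otherwise do)
sc.pp.log1p(adata)

# Highly variable gene selection, run separately on each dataset
# (the step select_hvg would otherwise do)
sc.pp.highly_variable_genes(adata, n_top_genes=1000)
adata = adata[:, adata.var.highly_variable].copy()
```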
My exact calls:

```python
sciPENN = sciPENN_API(gene_trainsets = [govek_rna_concat, gayoso_rna_concat_111,
                                        gayoso_rna_concat_208, schroer_rna_concat],
                      protein_trainsets = [govek_prot_concat, gayoso_prot_concat_111,
                                           gayoso_prot_concat_208, schroer_prot_concat],
                      gene_test = adatas_concat,
                      select_hvg = False,
                      train_batchkeys = ['batch', 'batch', 'batch', 'batch'],
                      test_batchkey = 'batch',
                      cell_normalize = False,
                      log_normalize = False,
                      gene_normalize = True,
                      use_gpu = False)

sciPENN.train(quantiles = [0.1, 0.25, 0.75, 0.9], n_epochs = 10000, ES_max = 12,
              decay_max = 6, decay_step = 0.1, lr = 10**(-3),
              weights_dir = "sciPENN", load = True)
```

And the full traceback:

```
IndexError                                Traceback (most recent call last)
Cell In[68], line 1
----> 1 sciPENN.train(quantiles = [0.1, 0.25, 0.75, 0.9], n_epochs = 10000, ES_max = 12, decay_max = 6,
      2               decay_step = 0.1, lr = 10**(-3), weights_dir = "sciPENN", load = True)

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/sciPENN/sciPENN_API.py:83, in sciPENN_API.train(self, quantiles, n_epochs, ES_max, decay_max, decay_step, lr, weights_dir, load)
     80 else:
     81     train_params = (self.dataloaders['train'], self.dataloaders['val'], n_epochs, ES_max, decay_max, decay_step, lr)
---> 83 self.model.train_backprop(*train_params)
     84 save(self.model.state_dict(), path)

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/sciPENN/Network/Model.py:105, in sciPENN_Model.train_backprop(self, train_loader, val_loader, n_epoch, ES_max, decay_max, decay_step, lr)
    103 for batch, inputs in enumerate(val_loader):
    104     mod1, mod2, protein_bools, celltypes = inputs
--> 105     outputs = self(mod1)
    107     n_correct = get_correct(outputs)
    108     mod2_loss = self.loss2(outputs['modality 2'], mod2, protein_bools)

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/sciPENN/Network/Model.py:71, in sciPENN_Model.forward_simple(self, x)
     70 def forward_simple(self, x):
---> 71     x = self.input_block(x)
     72     h = self.RNNCell(x, zeros_like(x))
     74     x = self.skip_1(x)

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   [... same Module._call_impl frame as above ...]

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/sciPENN/Network/Layers.py:17, in Input_Block.forward(self, x_new)
     16 def forward(self, x_new):
---> 17     x_new = self.bnorm_in(x_new)
     18     x_new = self.dropout_in(x_new)
     20     x = self.dense(x_new)

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   [... same Module._call_impl frame as above ...]

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:171, in _BatchNorm.forward(self, input)
    164 bn_training = (self.running_mean is None) and (self.running_var is None)
    166 r"""
    167 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
    168 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
    169 used for normalization (i.e. in eval mode when buffers are not None).
    170 """
--> 171 return F.batch_norm(
    172     input,
    173     # If buffers are not to be tracked, ensure that they won't be updated
    174     self.running_mean
    175     if not self.training or self.track_running_stats
    176     else None,
    177     self.running_var if not self.training or self.track_running_stats else None,
    178     self.weight,
    179     self.bias,
    180     bn_training,
    181     exponential_average_factor,
    182     self.eps,
    183 )

File ~/miniforge3/envs/single-cell/lib/python3.9/site-packages/torch/nn/functional.py:2450, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   2447 if training:
   2448     _verify_batch_size(input.size())
-> 2450 return torch.batch_norm(
   2451     input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
   2452 )

IndexError: select(): index 0 out of range for tensor of size [0] at dimension 0
```
Solved: the issue was due to different highly variable genes being selected in each dataset. I ended up using sciPENN_API() with default parameters after some QC filtering of the cells.
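For anyone who hits the same IndexError: the last frame shows F.batch_norm receiving a tensor of size [0], i.e. an empty tensor reaching the first BatchNorm layer, which in my case traced back to per-dataset HVG selection leaving (nearly) no genes shared across the references and the query. Below is a minimal sketch of the check and of the rerun with defaults; the variable names match my call above, and the QC threshold is a placeholder, not my exact setting:

```python
from functools import reduce

import scanpy as sc
from sciPENN.sciPENN_API import sciPENN_API

# Check how many genes the references and the query actually share;
# independent per-dataset HVG selection can shrink this intersection
# to (almost) zero, which is what broke training for me.
gene_sets = [set(a.var_names) for a in
             (govek_rna_concat, gayoso_rna_concat_111,
              gayoso_rna_concat_208, schroer_rna_concat, adatas_concat)]
shared = reduce(set.intersection, gene_sets)
print(f"{len(shared)} genes shared across all datasets")

# Fix: basic QC on the cells, then let sciPENN handle normalization
# and HVG selection itself by keeping the default parameters.
for a in (govek_rna_concat, gayoso_rna_concat_111,
          gayoso_rna_concat_208, schroer_rna_concat, adatas_concat):
    sc.pp.filter_cells(a, min_genes=200)  # placeholder threshold

sciPENN = sciPENN_API(gene_trainsets = [govek_rna_concat, gayoso_rna_concat_111,
                                        gayoso_rna_concat_208, schroer_rna_concat],
                      protein_trainsets = [govek_prot_concat, gayoso_prot_concat_111,
                                           gayoso_prot_concat_208, schroer_prot_concat],
                      gene_test = adatas_concat,
                      train_batchkeys = ['batch', 'batch', 'batch', 'batch'],
                      test_batchkey = 'batch',
                      use_gpu = False)
```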