LiQian-XC / sctour

A deep learning architecture for robust inference and accurate prediction of cellular dynamics
https://sctour.readthedocs.io
MIT License

tnode.train() Error #9

Closed: obrien-james closed this issue 4 months ago

obrien-james commented 4 months ago

Hi,

When I run the following code:

import sctour as sct

tnode = sct.train.Trainer(adata, use_gpu=False)
tnode.train()
adata.obs['ptime'] = tnode.get_time()
mix_zs, zs, pred_zs = tnode.get_latentsp(alpha_z=0.7, alpha_predz=0.3)
adata.obsm['X_TNODE'] = mix_zs
adata.obsm['X_VF'] = tnode.get_vector_field(adata.obs['ptime'].values, mix_zs)

I'm getting the following error; it seems to fail at tnode.train():

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 2
      1 tnode = sct.train.Trainer(adata, use_gpu=False)
----> 2 tnode.train()
      3 adata.obs['ptime'] = tnode.get_time()
      4 mix_zs, zs, pred_zs = tnode.get_latentsp(alpha_z=0.7, alpha_predz=0.3)

File ~/miniconda3/envs/parse/lib/python3.10/site-packages/sctour/train.py:265, in Trainer.train(self)
    263 with tqdm(total=self.nepoch, unit='epoch') as t:
    264     for tepoch in range(t.total):
--> 265         train_loss = self._on_epoch_train(self.train_dl)
    266         val_loss = self._on_epoch_val(self.val_dl)
    267         self.log['train_loss'].append(train_loss)

File ~/miniconda3/envs/parse/lib/python3.10/site-packages/sctour/train.py:296, in Trainer._on_epoch_train(self, DL)
    294 X = X.to(self.device)
    295 Y = Y.to(self.device)
--> 296 loss, recon_loss_ec, recon_loss_ode, kl_div, z_div = self.model(X, Y)
    297 loss.backward()
    298 self.optimizer.step()

File ~/miniconda3/envs/parse/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
...
File ~/miniconda3/envs/parse/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float

I'm not sure how to debug this error. Are you able to provide any insight?

LiQian-XC commented 4 months ago

Can you check the data type of adata.X? If it's not float32, I think you need to convert it to float32 before running scTour, for example with adata.X = adata.X.astype('float32'). Please let me know if this does not work.
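A minimal sketch of the check suggested above, using a hypothetical helper `ensure_float32` (it is not part of scTour or AnnData). It assumes the expression matrix is a NumPy array; SciPy sparse matrices also expose `.dtype` and `.astype()`, so the same check should apply to them, but that case is not exercised here.

```python
import numpy as np

def ensure_float32(X):
    """Hypothetical helper: return X cast to float32 if it has any other dtype.

    If X is already float32, the same object is returned unchanged.
    """
    if X.dtype != np.float32:
        X = X.astype(np.float32)  # astype returns a converted copy
    return X

# Stand-in for count data that ended up as float64 (NumPy's default float type).
X = np.random.default_rng(0).poisson(2.0, size=(5, 4)).astype(np.float64)
X32 = ensure_float32(X)
print(X.dtype, "->", X32.dtype)
```

With an AnnData object this would be used as `adata.X = ensure_float32(adata.X)` before constructing `sct.train.Trainer(adata, ...)`. Checking first avoids an unnecessary copy when the matrix is already float32.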

obrien-james commented 4 months ago

Thanks, this was the issue.
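For anyone hitting the same message: the traceback ends in `F.linear(input, self.weight, self.bias)`, and "Double and Float" means the input batch was float64 while the layer weights were float32. The sketch below reproduces this with a plain `torch.nn.Linear` outside scTour, assuming (as the error suggests) that scTour's layers keep PyTorch's default float32 parameters. The layer sizes and random data are illustrative only, and the exact error wording may differ between PyTorch versions.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 2)                 # parameters are float32 by default
x64 = torch.rand(3, 4, dtype=torch.float64)   # stands in for float64 expression data

try:
    layer(x64)                                # float64 input vs float32 weight
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

out = layer(x64.float())                      # casting the input to float32 fixes it
print(out.dtype, tuple(out.shape))
```

Converting the data, as done in this thread, is usually preferable to converting the model to float64 with `.double()`, which would double parameter memory and diverge from how the package builds its model.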