cistrome / MIRA

Python package for analysis of multiomic single cell RNA-seq and ATAC-seq.

NITE model prediction expected scalar type Double but found Float #34

Closed Yansr3 closed 6 months ago

Yansr3 commented 8 months ago

I'm following the tutorial on my own data to train the NITE model. When I run the predict function on nitemodel, I get this runtime error:

rp_args = dict(expr_adata = rna_data, atac_adata = atac_data)
nitemodel.predict(**rp_args)


RuntimeError                              Traceback (most recent call last)
File ~/miniconda3/envs/Mira/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/miniconda3/envs/Mira/lib/python3.9/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/Mira/lib/python3.9/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/Mira/lib/python3.9/site-packages/mira/rp_model/rp_model.py:927, in GeneModel.model(self, gene_expr, correction_vector, softmax_denom, read_depth, upstream_weights, upstream_distances, downstream_weights, downstream_distances, promoter_weights, NITE_features)
    926 if self.use_NITE_features:
--> 927     f_Z = f_Z + torch.matmul(NITE_features, torch.unsqueeze(a_NITE, 0).T).reshape(-1)
    929 expr_prediction = gamma* self.bn(f_Z.reshape((-1,1)).float()).reshape(-1) + bias

RuntimeError: expected scalar type Double but found Float

AllenWLynch commented 6 months ago

This is a problem with new datatype specifications for numpy. If you install the current pre-release version this issue is resolved:

pip install mira-multiome==2.1.1a4

Yansr3 commented 5 months ago

Sorry to bother you again. I tried version 2.1.1a4 but hit a slightly different error when running nitemodel.predict. Could you offer some advice on the problem?

rp_args = dict(expr_adata = rna_data, atac_adata = atac_data)
nitemodel.predict(**rp_args)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/.pyenv/versions/3.11.5/envs/mira-2.1.1a4/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/.pyenv/versions/3.11.5/envs/mira-2.1.1a4/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/.pyenv/versions/3.11.5/envs/mira-2.1.1a4/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/.pyenv/versions/3.11.5/envs/mira-2.1.1a4/lib/python3.11/site-packages/mira/rp_model/rp_model.py:927, in GeneModel.model(self, gene_expr, correction_vector, softmax_denom, read_depth, upstream_weights, upstream_distances, downstream_weights, downstream_distances, promoter_weights, NITE_features)
    926 if self.use_NITE_features:
--> 927     f_Z = f_Z + torch.matmul(NITE_features, torch.unsqueeze(a_NITE, 0).T).reshape(-1)
    929 expr_prediction = gamma* self.bn(f_Z.reshape((-1,1)).float()).reshape(-1) + bias

RuntimeError: expected m1 and m2 to have the same dtype, but got: double != float

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
...
  NITE_TJP2/a_NITE dist 18 |
                  value 18 |
Trace Shapes:
 Param Sites:
Sample Sites:

Before training the NITE model, I used the object for LITE model training and prediction. Both .fit and .predict worked for that model. The numpy version in my environment is 1.26.2.

mira-multiome             2.1.1a4
numpy                     1.26.2

Yansr3 commented 5 months ago

For reference, and this may not be the best way to solve the problem: I got nitemodel.predict to work by adding .float() after NITE_features on line 927 of rp_model.py.

if self.use_NITE_features:
    f_Z = f_Z + torch.matmul(NITE_features.float(), torch.unsqueeze(a_NITE, 0).T).reshape(-1)
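
A minimal sketch of why this cast helps, written outside MIRA. The tensors below are hypothetical stand-ins that reuse the names from line 927; their shapes, values, and dtypes are assumptions chosen only to reproduce the same kind of mismatch, not what MIRA actually passes in.

```python
import torch

# Hypothetical stand-ins for the operands on line 927 of rp_model.py.
# Assumption: the features arrive as Double (float64) while the sampled
# weight is Float (float32), which matches the error message above.
NITE_features = torch.ones(4, 3, dtype=torch.float64)
a_NITE = torch.ones(3, dtype=torch.float32)

# Same reshaping as the original line: (3,) -> (1, 3) -> (3, 1).
weights = torch.unsqueeze(a_NITE, 0).T

# torch.matmul does not promote mixed dtypes, so Double x Float raises.
try:
    torch.matmul(NITE_features, weights)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Casting the features to float32 makes both operands agree.
contribution = torch.matmul(NITE_features.float(), weights).reshape(-1)
print(mismatch_raised, contribution.dtype, tuple(contribution.shape))
```

Casting at the point of use keeps the change to one line. Converting the features to float32 where they are first built would be an alternative, but whether that is how the package itself handles it is not stated in this thread.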