MICS-Lab / scyan

Biology-driven deep generative model for cell-type annotation in cytometry. Scyan is an interpretable model that also corrects batch-effect and can be used for debarcoding or population discovery.
https://mics-lab.github.io/scyan/
BSD 3-Clause "New" or "Revised" License

[Bug] predict() fails after training on GPU #32

Closed grst closed 5 months ago

grst commented 6 months ago

Description

Calling model.predict() after training on GPU does not work, whether it is called before or after batch-effect correction. It fails with:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Reproducing the issue

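import pandas as pd
import scanpy as sc
import scyan

adata = sc.read_h5ad(...)  # cytometry data as an AnnData object
table = pd.read_excel(...)  # marker/population knowledge table
model = scyan.Scyan(adata, table, prior_std=0.25, lr=0.0001, batch_key="sample_id")
model.fit(accelerator="gpu", profiler="simple")  # training on GPU works
model.predict()  # raises the RuntimeError shown above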
Traceback

```pytb
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 model.predict()

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/utils.py:145, in _requires_fit.<locals>.wrapper(model, *args, **kwargs)
    140 @wraps(f)
    141 def wrapper(model, *args, **kwargs):
    142     assert (
    143         model._is_fitted
    144     ), "The model has to be trained first, consider running 'model.fit()'"
--> 145     return f(model, *args, **kwargs)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/model.py:341, in Scyan.predict(self, key_added, add_levels, log_prob_th)
    320 @_requires_fit
    321 @torch.no_grad()
    322 def predict(
   (...)
    326     log_prob_th: float = -50,
    327 ) -> pd.Series:
    328     """Model population predictions, i.e. one population is assigned for each cell. Predictions are saved in `adata.obs.scyan_pop` by default.
    329
    330     !!! note
   (...)
    339         Population predictions (pandas `Series` of length $N$ cells).
    340     """
--> 341     df = self.predict_proba()
    342     self.adata.obs["scyan_log_probs"] = df["max_log_prob_u"].values
    344     populations = df.iloc[:, : self.n_pops].idxmax(axis=1).astype("category")

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/utils.py:145, in _requires_fit.<locals>.wrapper(model, *args, **kwargs)
    140 @wraps(f)
    141 def wrapper(model, *args, **kwargs):
    142     assert (
    143         model._is_fitted
    144     ), "The model has to be trained first, consider running 'model.fit()'"
--> 145     return f(model, *args, **kwargs)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/model.py:370, in Scyan.predict_proba(self)
    362 @_requires_fit
    363 @torch.no_grad()
    364 def predict_proba(self) -> pd.DataFrame:
    365     """Soft predictions (i.e. an array of probability per population) for each cell.
    366
    367     Returns:
    368         Dataframe of shape `(N, P)` with probabilities for each population.
    369     """
--> 370     log_probs = self.dataset_apply(
    371         lambda *data: self.module.compute_probabilities(*data)[0]
    372     )
    373     probs = torch.softmax(log_probs, dim=1)
    375     df = pd.DataFrame(probs.numpy(force=True), columns=self.pop_names)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/model.py:461, in Scyan.dataset_apply(self, func, data)
    453 else:
    454     loader = DataLoader(
    455         TensorDataset(*data),
    456         batch_size=self._batch_size,
    457         num_workers=self._num_workers,
    458     )
    460 return torch.cat(
--> 461     [func(*batch) for batch in tqdm(loader, desc="DataLoader")], dim=0
    462 )

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/model.py:461, in <listcomp>(.0)
    453 else:
    454     loader = DataLoader(
    455         TensorDataset(*data),
    456         batch_size=self._batch_size,
    457         num_workers=self._num_workers,
    458     )
    460 return torch.cat(
--> 461     [func(*batch) for batch in tqdm(loader, desc="DataLoader")], dim=0
    462 )

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/model.py:371, in Scyan.predict_proba.<locals>.<lambda>(*data)
    362 @_requires_fit
    363 @torch.no_grad()
    364 def predict_proba(self) -> pd.DataFrame:
    365     """Soft predictions (i.e. an array of probability per population) for each cell.
    366
    367     Returns:
    368         Dataframe of shape `(N, P)` with probabilities for each population.
    369     """
    370     log_probs = self.dataset_apply(
--> 371         lambda *data: self.module.compute_probabilities(*data)[0]
    372     )
    373     probs = torch.softmax(log_probs, dim=1)
    375     df = pd.DataFrame(probs.numpy(force=True), columns=self.pop_names)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/module/scyan_module.py:166, in ScyanModule.compute_probabilities(self, x, covariates, use_temp)
    154 def compute_probabilities(
    155     self, x: Tensor, covariates: Tensor, use_temp: bool = False
    156 ) -> Tuple[Tensor, Tensor, Tensor]:
    157     """Compute probabilities used in the loss function.
    158
    159     Args:
   (...)
    164         Log probabilities of size $(B, P)$, the log det jacobian and the latent expressions of size $(B, M)$.
    165     """
--> 166     u, _, ldj_sum = self(x, covariates)
    168     log_pi = (
    169         self.log_pi_temperature(-self.hparams.temperature)
    170         if use_temp
    171         else self.log_pi
    172     )
    174     log_probs = self.prior.log_prob(u) + log_pi  # size N x P

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/module/scyan_module.py:74, in ScyanModule.forward(self, x, covariates)
     64 def forward(self, x: Tensor, covariates: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     65     """Forward implementation, going through the complete flow $f_{\phi}$.
     66
     67     Args:
   (...)
     72         Tuple of (outputs, covariates, lod_det_jacobian sum)
     73     """
---> 74     return self.real_nvp(x, covariates)

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/module/real_nvp.py:61, in RealNVP.forward(self, x, covariates)
     51 def forward(self, x: Tensor, covariates: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     52     """Forward implementation, i.e. $f_{\phi}$.
     53
     54     Args:
   (...)
     59         Tuple of (outputs, covariates, lod_det_jacobian sum)
     60     """
---> 61     return self.module((x, covariates, None))

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
    202 def forward(self, input):
    203     for module in self:
--> 204         input = module(input)
    205     return input

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/module/coupling_layer.py:68, in CouplingLayer.forward(self, inputs)
     58 """Coupling layer forward function.
     59
     60 Args:
   (...)
     64     outputs, covariates, lod_det_jacobian sum
     65 """
     66 x, covariates, ldj_sum = inputs
---> 68 x_m = x * self.mask
     69 st_input = torch.cat([x_m, 100 * covariates], dim=1)
     71 s_out = self.sfun(st_input)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

System

Package Version
aiohttp 3.9.3
aiosignal 1.3.1
anndata 0.10.6
anyio 4.3.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
array_api_compat 1.4.1
arrow 1.3.0
asciitree 0.3.3
asttokens 2.4.1
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.2.0
Babel 2.14.0
beautifulsoup4 4.12.3
bleach 6.1.0
certifi 2024.2.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorcet 3.1.0
comm 0.2.2
contourpy 1.2.0
cycler 0.12.1
dask 2024.3.1
dask-expr 1.0.3
datashader 0.16.0
debugpy 1.8.1
decorator 5.1.1
defusedxml 0.7.1
et-xmlfile 1.1.0
exceptiongroup 1.2.0
executing 2.0.1
fasteners 0.19
fastjsonschema 2.19.1
fcsparser 0.2.8
fcswrite 0.6.2
FlowUtils 1.0.0
fonttools 4.50.0
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2024.3.0
get-annotations 0.1.2
h11 0.14.0
h5py 3.10.0
httpcore 1.0.4
httpx 0.27.0
idna 3.6
importlib_metadata 7.0.2
importlib_resources 6.3.1
ipykernel 6.29.3
ipylab 1.0.0
ipython 8.18.1
ipywidgets 8.1.2
isoduration 20.11.0
jedi 0.19.1
Jinja2 3.1.3
joblib 1.3.2
json5 0.9.24
jsonpointer 2.4
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter_client 8.6.1
jupyter_core 5.7.2
jupyter-events 0.9.1
jupyter-lsp 2.2.4
jupyter_server 2.13.0
jupyter_server_terminals 0.5.3
jupyterlab 4.1.5
jupyterlab_pygments 0.3.0
jupyterlab_server 2.25.4
jupyterlab_widgets 3.0.10
kiwisolver 1.4.5
lamin_utils 0.13.0
legacy-api-wrap 1.4
lightning-utilities 0.10.1
llvmlite 0.42.0
locket 1.0.0
MarkupSafe 2.1.5
matplotlib 3.8.3
matplotlib-inline 0.1.6
mistune 3.0.2
multidict 6.0.5
multipledispatch 1.0.0
natsort 8.4.0
nbclient 0.10.0
nbconvert 7.16.2
nbformat 5.10.3
nbproject 0.10.1
nest-asyncio 1.6.0
networkx 3.2.1
notebook 7.1.2
notebook_shim 0.2.4
numba 0.59.0
numcodecs 0.12.1
numpy 1.26.4
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
openpyxl 3.1.2
orjson 3.9.15
overrides 7.7.0
packaging 24.0
pandas 2.2.1
pandocfilters 1.5.1
param 2.0.2
parso 0.8.3
partd 1.4.1
patsy 0.5.6
pexpect 4.9.0
pillow 10.2.0
pip 24.0
platformdirs 4.2.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.1
pycparser 2.21
pyct 0.5.0
pydantic 1.10.14
Pygments 2.17.2
pynndescent 0.5.11
pyparsing 3.1.2
python-dateutil 2.9.0.post0
python-json-logger 2.0.7
pytometry 0.1.4
pytorch-lightning 1.9.5
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
readfcs 1.1.7
referencing 0.34.0
requests 2.31.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rpds-py 0.18.0
scanpy 1.10.0rc2
scikit-learn 1.4.1.post1
scipy 1.12.0
scyan 1.5.4
seaborn 0.13.2
Send2Trash 1.8.2
session_info 1.0.0
setuptools 69.2.0
six 1.16.0
sniffio 1.3.1
soupsieve 2.5
stack-data 0.6.3
statsmodels 0.14.1
stdlib-list 0.10.0
terminado 0.18.1
threadpoolctl 3.3.0
tinycss2 1.2.1
tomli 2.0.1
toolz 0.12.1
torch 1.13.1
torchmetrics 1.3.1
tornado 6.4
tqdm 4.66.2
traitlets 5.14.2
types-python-dateutil 2.9.0.20240316
typing_extensions 4.10.0
tzdata 2024.1
umap-learn 0.5.5
uri-template 1.3.0
urllib3 2.2.1
wcwidth 0.2.13
webcolors 1.13
webencodings 0.5.1
websocket-client 1.7.0
wheel 0.42.0
widgetsnbextension 4.0.10
xarray 2024.2.0
yarl 1.9.4
zarr 2.17.1
zipp 3.18.1

quentinblampey commented 6 months ago

Hello @grst, thanks for posting this issue, I was able to reproduce it.

After training, pytorch_lightning switches the model back to CPU, but it seems that it doesn't do so for model.dataset. This is very strange; I'll try more recent versions of pytorch lightning to see if this is fixed (I need to allow more recent dependency versions, as you said in another issue).
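To check for this kind of mismatch, one can compare the devices used by the module with those used by the dataset tensors. The following is only a minimal sketch in plain PyTorch: `devices_in_use` is a hypothetical helper, not part of scyan, and the toy objects stand in for `model.module` and `model.dataset` (attribute names taken from the traceback and the workaround in this thread, so they may not exist in other scyan versions).

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


def devices_in_use(module: nn.Module, dataset: TensorDataset) -> dict:
    """Hypothetical helper: device types used by a module and by a dataset's tensors."""
    # Parameters and buffers (e.g. a coupling-layer mask) both live on the module's device.
    module_devices = {p.device.type for p in module.parameters()}
    module_devices |= {b.device.type for b in module.buffers()}
    dataset_devices = {t.device.type for t in dataset.tensors}
    return {"module": module_devices, "dataset": dataset_devices}


# Toy stand-ins; with an affected scyan version one would presumably pass
# model.module and model.dataset instead (assumption).
toy_module = nn.Linear(4, 2)
toy_dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
report = devices_in_use(toy_module, toy_dataset)
print(report)
```

If the two sets differ (for instance `cpu` for the module and `cuda` for the dataset), a forward pass over the dataset would be expected to raise the error reported above.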

Meanwhile, you can do this operation right after model.fit:

model.dataset.tensors = tuple(x.to("cpu") for x in model.dataset.tensors)
model.predict() # this should now run
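A slightly more general variant of this workaround moves the dataset tensors to whatever device the module currently uses, instead of hard-coding "cpu". This is a sketch under assumptions: `align_dataset_to_module` is a hypothetical helper name, and the toy module and dataset stand in for `model.module` and `model.dataset`.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


def align_dataset_to_module(dataset: TensorDataset, module: nn.Module) -> torch.device:
    """Hypothetical helper: move every dataset tensor to the module's device."""
    # Assumes the module has at least one parameter and lives on a single device.
    device = next(module.parameters()).device
    dataset.tensors = tuple(t.to(device) for t in dataset.tensors)
    return device


# Toy stand-ins for model.module and model.dataset (assumption).
toy_module = nn.Linear(4, 2)
toy_dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
device = align_dataset_to_module(toy_dataset, toy_module)
print(device, [t.device for t in toy_dataset.tensors])
```

Reading the device from the module keeps the two in sync whether the module ended up on CPU or stayed on GPU.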

I also don't really like the fact that pytorch lightning switches everything back to CPU. It means we need to re-use the trainer during prediction if we want to re-use the GPU; otherwise the prediction runs on CPU (which is what is currently done).
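For illustration, inference on a chosen device without a trainer could look like the loop below. This is a plain PyTorch sketch, not scyan's or Lightning's implementation: `predict_on_device` is a hypothetical function, and a toy linear layer stands in for the real module.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_on_device(module, dataset, device, batch_size=1024):
    """Hypothetical helper: run a module over a dataset on `device`, return CPU outputs."""
    module = module.to(device).eval()  # note: the module stays on `device` afterwards
    outputs = []
    for (x,) in DataLoader(dataset, batch_size=batch_size):
        # Move each batch to the module's device, then bring the result back to CPU.
        outputs.append(module(x.to(device)).cpu())
    return torch.cat(outputs, dim=0)


# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy_module = nn.Linear(4, 2)
toy_dataset = TensorDataset(torch.randn(10, 4))
preds = predict_on_device(toy_module, toy_dataset, device, batch_size=4)
print(preds.shape, preds.device)
```

Moving each batch explicitly avoids relying on where the dataset tensors happen to live, at the cost of a host-to-device copy per batch.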

quentinblampey commented 5 months ago

Actually, it was an easy fix: I simply removed dataset from the attributes of the model. GPU usage should now work fine on scyan==1.6.1.
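To decide whether the workaround is still needed, one can compare the installed scyan version with 1.6.1. The sketch below uses only the standard library; the helper names are hypothetical, and it assumes that releases after 1.6.1 also contain the fix, which this thread does not state explicitly.

```python
from importlib import metadata
from itertools import takewhile

FIXED_IN = (1, 6, 1)  # version named in this thread as containing the fix


def parse_version(version: str) -> tuple:
    """Turn '1.10.0rc2' into (1, 10, 0), keeping only the leading digits of each part."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def has_gpu_predict_fix(version: str) -> bool:
    """Assumption: every release from 1.6.1 onwards includes the fix."""
    return parse_version(version) >= FIXED_IN


def installed_scyan_has_fix():
    """Return True/False for the installed scyan, or None if scyan is not installed."""
    try:
        return has_gpu_predict_fix(metadata.version("scyan"))
    except metadata.PackageNotFoundError:
        return None


print(has_gpu_predict_fix("1.5.4"), has_gpu_predict_fix("1.6.1"))
```

With the version from the report above (scyan 1.5.4), the check indicates that upgrading or applying the workaround is still required.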