grst closed this issue 5 months ago.
Hello @grst, thanks for posting this issue; I was able to reproduce it.
After training, pytorch_lightning moves the model back to the CPU, but it seems that it doesn't do so for `model.dataset`. This is very strange. I'll try more recent versions of PyTorch Lightning to see if this is fixed (I need to allow more recent dependency versions, as you said in another issue).
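The behaviour described above is consistent with how `torch.nn.Module.to()` works: it converts registered parameters and buffers (recursing into submodules) but leaves plain Python attributes untouched, and a `TensorDataset` stored as `self.dataset` is such a plain attribute. Below is a minimal sketch, assuming a hypothetical `ToyModel` rather than scyan's actual model class. It converts the dtype instead of the device so that it runs without a GPU, but the same rule applies to `.to("cuda")` and `.cpu()`.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


class ToyModel(nn.Module):
    """Hypothetical stand-in for a model that keeps its training data as an attribute."""

    def __init__(self, x: torch.Tensor):
        super().__init__()
        self.linear = nn.Linear(x.shape[1], 1)  # registered parameters: moved by .to()
        self.dataset = TensorDataset(x)  # plain attribute: NOT moved by .to()


model = ToyModel(torch.randn(8, 3))

# Module.to() only converts parameters and buffers.
model.to(torch.float64)

print(model.linear.weight.dtype)       # torch.float64
print(model.dataset.tensors[0].dtype)  # torch.float32 (unchanged)
```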
Meanwhile, you can run this right after `model.fit()`:

```python
# Move the tensors of the training dataset back to the CPU, where the model now is
model.dataset.tensors = tuple(x.to("cpu") for x in model.dataset.tensors)
model.predict()  # this should now run
```
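The same workaround can be written as a small helper that moves the dataset tensors to whatever device the model's parameters are on, instead of hard-coding `"cpu"`. This is a sketch that assumes the dataset is a `torch.utils.data.TensorDataset`; `align_dataset_with_model` is a hypothetical name, not part of scyan, and an `nn.Linear` stands in for the model.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


def align_dataset_with_model(model: nn.Module, dataset: TensorDataset) -> TensorDataset:
    """Hypothetical helper: move every tensor of `dataset` to the device of `model`'s parameters."""
    device = next(model.parameters()).device
    # TensorDataset keeps its data in the `tensors` tuple; rebuild it on the target device.
    dataset.tensors = tuple(t.to(device) for t in dataset.tensors)
    return dataset


model = nn.Linear(3, 1)
dataset = TensorDataset(torch.randn(10, 3), torch.randn(10, 1))
align_dataset_with_model(model, dataset)

print(all(t.device == next(model.parameters()).device for t in dataset.tensors))  # True
```

Reading the device from the parameters means the helper keeps working whether the model ended up on the CPU or on a GPU after training.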
I also don't really like that PyTorch Lightning moves everything back to the CPU: it means we need to reuse the trainer during prediction if we want to keep using the GPU; otherwise prediction runs on the CPU (which is what currently happens).
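One way to keep prediction on the GPU without reusing the Lightning trainer is to handle devices explicitly: move the model to the target device, then move each batch there before the forward pass. This is a plain-PyTorch sketch with a hypothetical `predict_on_device` helper and an `nn.Linear` standing in for the model; it is not how scyan implements `predict()`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_on_device(
    model: nn.Module, dataset: TensorDataset, device: str, batch_size: int = 256
) -> torch.Tensor:
    """Hypothetical helper: run inference batch by batch on `device`, return outputs on the CPU."""
    model = model.to(device).eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    outputs = []
    for (x,) in loader:
        # Each batch is moved to the same device as the model before the forward pass.
        outputs.append(model(x.to(device)).cpu())
    return torch.cat(outputs)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(3, 1)
dataset = TensorDataset(torch.randn(1000, 3))
preds = predict_on_device(model, dataset, device)
print(preds.shape)  # torch.Size([1000, 1])
```

Moving batches rather than the whole dataset keeps GPU memory usage bounded by `batch_size`.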
Actually, it was an easy fix: I simply removed `dataset` from the model's attributes.
GPU usage should now work fine with `scyan==1.6.1`.
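With `dataset` no longer stored on the model, nothing is left behind when the model changes device. To spot similar leftovers in other modules, one can list attributes that hold tensors or `TensorDataset`s without being registered as parameters or buffers. This is a sketch with a hypothetical `find_unregistered_tensors` helper and a toy model. It relies on an implementation detail of `nn.Module`: parameters, buffers and submodules are kept in internal dicts, so only plain attributes appear directly in `vars()`.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


def find_unregistered_tensors(module: nn.Module) -> list:
    """Hypothetical helper: list attributes holding data that Module.to() will not move."""
    found = []
    for prefix, sub in module.named_modules():
        # Parameters, buffers and submodules live in _parameters/_buffers/_modules;
        # anything else assigned with `self.x = ...` is a plain entry in the instance dict.
        for name, value in vars(sub).items():
            if isinstance(value, (torch.Tensor, TensorDataset)):
                found.append(f"{prefix}.{name}" if prefix else name)
    return found


class ToyModel(nn.Module):
    """Hypothetical model that stores its training data as a plain attribute."""

    def __init__(self, x: torch.Tensor):
        super().__init__()
        self.linear = nn.Linear(x.shape[1], 1)
        self.dataset = TensorDataset(x)


model = ToyModel(torch.randn(8, 3))
print(find_unregistered_tensors(model))  # ['dataset']

del model.dataset  # mirror the fix: the model no longer stores the data
print(find_unregistered_tensors(model))  # []
```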
Description
`model.predict()` does not work on GPU, whether it is called before or after batch effect correction. It fails with the `RuntimeError` shown in the traceback below.

Reproducing the issue
Traceback
```pytb
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 model.predict()

File /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/scyan/utils.py:145, in _requires_fit.
```

System
Dependency versions
| Package | Version |
|---|---|
| aiohttp | 3.9.3 |
| aiosignal | 1.3.1 |
| anndata | 0.10.6 |
| anyio | 4.3.0 |
| argon2-cffi | 23.1.0 |
| argon2-cffi-bindings | 21.2.0 |
| array_api_compat | 1.4.1 |
| arrow | 1.3.0 |
| asciitree | 0.3.3 |
| asttokens | 2.4.1 |
| async-lru | 2.0.4 |
| async-timeout | 4.0.3 |
| attrs | 23.2.0 |
| Babel | 2.14.0 |
| beautifulsoup4 | 4.12.3 |
| bleach | 6.1.0 |
| certifi | 2024.2.2 |
| cffi | 1.16.0 |
| charset-normalizer | 3.3.2 |
| click | 8.1.7 |
| cloudpickle | 3.0.0 |
| colorcet | 3.1.0 |
| comm | 0.2.2 |
| contourpy | 1.2.0 |
| cycler | 0.12.1 |
| dask | 2024.3.1 |
| dask-expr | 1.0.3 |
| datashader | 0.16.0 |
| debugpy | 1.8.1 |
| decorator | 5.1.1 |
| defusedxml | 0.7.1 |
| et-xmlfile | 1.1.0 |
| exceptiongroup | 1.2.0 |
| executing | 2.0.1 |
| fasteners | 0.19 |
| fastjsonschema | 2.19.1 |
| fcsparser | 0.2.8 |
| fcswrite | 0.6.2 |
| FlowUtils | 1.0.0 |
| fonttools | 4.50.0 |
| fqdn | 1.5.1 |
| frozenlist | 1.4.1 |
| fsspec | 2024.3.0 |
| get-annotations | 0.1.2 |
| h11 | 0.14.0 |
| h5py | 3.10.0 |
| httpcore | 1.0.4 |
| httpx | 0.27.0 |
| idna | 3.6 |
| importlib_metadata | 7.0.2 |
| importlib_resources | 6.3.1 |
| ipykernel | 6.29.3 |
| ipylab | 1.0.0 |
| ipython | 8.18.1 |
| ipywidgets | 8.1.2 |
| isoduration | 20.11.0 |
| jedi | 0.19.1 |
| Jinja2 | 3.1.3 |
| joblib | 1.3.2 |
| json5 | 0.9.24 |
| jsonpointer | 2.4 |
| jsonschema | 4.21.1 |
| jsonschema-specifications | 2023.12.1 |
| jupyter_client | 8.6.1 |
| jupyter_core | 5.7.2 |
| jupyter-events | 0.9.1 |
| jupyter-lsp | 2.2.4 |
| jupyter_server | 2.13.0 |
| jupyter_server_terminals | 0.5.3 |
| jupyterlab | 4.1.5 |
| jupyterlab_pygments | 0.3.0 |
| jupyterlab_server | 2.25.4 |
| jupyterlab_widgets | 3.0.10 |
| kiwisolver | 1.4.5 |
| lamin_utils | 0.13.0 |
| legacy-api-wrap | 1.4 |
| lightning-utilities | 0.10.1 |
| llvmlite | 0.42.0 |
| locket | 1.0.0 |
| MarkupSafe | 2.1.5 |
| matplotlib | 3.8.3 |
| matplotlib-inline | 0.1.6 |
| mistune | 3.0.2 |
| multidict | 6.0.5 |
| multipledispatch | 1.0.0 |
| natsort | 8.4.0 |
| nbclient | 0.10.0 |
| nbconvert | 7.16.2 |
| nbformat | 5.10.3 |
| nbproject | 0.10.1 |
| nest-asyncio | 1.6.0 |
| networkx | 3.2.1 |
| notebook | 7.1.2 |
| notebook_shim | 0.2.4 |
| numba | 0.59.0 |
| numcodecs | 0.12.1 |
| numpy | 1.26.4 |
| nvidia-cublas-cu11 | 11.10.3.66 |
| nvidia-cuda-nvrtc-cu11 | 11.7.99 |
| nvidia-cuda-runtime-cu11 | 11.7.99 |
| nvidia-cudnn-cu11 | 8.5.0.96 |
| openpyxl | 3.1.2 |
| orjson | 3.9.15 |
| overrides | 7.7.0 |
| packaging | 24.0 |
| pandas | 2.2.1 |
| pandocfilters | 1.5.1 |
| param | 2.0.2 |
| parso | 0.8.3 |
| partd | 1.4.1 |
| patsy | 0.5.6 |
| pexpect | 4.9.0 |
| pillow | 10.2.0 |
| pip | 24.0 |
| platformdirs | 4.2.0 |
| prometheus_client | 0.20.0 |
| prompt-toolkit | 3.0.43 |
| psutil | 5.9.8 |
| ptyprocess | 0.7.0 |
| pure-eval | 0.2.2 |
| pyarrow | 15.0.1 |
| pycparser | 2.21 |
| pyct | 0.5.0 |
| pydantic | 1.10.14 |
| Pygments | 2.17.2 |
| pynndescent | 0.5.11 |
| pyparsing | 3.1.2 |
| python-dateutil | 2.9.0.post0 |
| python-json-logger | 2.0.7 |
| pytometry | 0.1.4 |
| pytorch-lightning | 1.9.5 |
| pytz | 2024.1 |
| PyYAML | 6.0.1 |
| pyzmq | 25.1.2 |
| readfcs | 1.1.7 |
| referencing | 0.34.0 |
| requests | 2.31.0 |
| rfc3339-validator | 0.1.4 |
| rfc3986-validator | 0.1.1 |
| rpds-py | 0.18.0 |
| scanpy | 1.10.0rc2 |
| scikit-learn | 1.4.1.post1 |
| scipy | 1.12.0 |
| scyan | 1.5.4 |
| seaborn | 0.13.2 |
| Send2Trash | 1.8.2 |
| session_info | 1.0.0 |
| setuptools | 69.2.0 |
| six | 1.16.0 |
| sniffio | 1.3.1 |
| soupsieve | 2.5 |
| stack-data | 0.6.3 |
| statsmodels | 0.14.1 |
| stdlib-list | 0.10.0 |
| terminado | 0.18.1 |
| threadpoolctl | 3.3.0 |
| tinycss2 | 1.2.1 |
| tomli | 2.0.1 |
| toolz | 0.12.1 |
| torch | 1.13.1 |
| torchmetrics | 1.3.1 |
| tornado | 6.4 |
| tqdm | 4.66.2 |
| traitlets | 5.14.2 |
| types-python-dateutil | 2.9.0.20240316 |
| typing_extensions | 4.10.0 |
| tzdata | 2024.1 |
| umap-learn | 0.5.5 |
| uri-template | 1.3.0 |
| urllib3 | 2.2.1 |
| wcwidth | 0.2.13 |
| webcolors | 1.13 |
| webencodings | 0.5.1 |
| websocket-client | 1.7.0 |
| wheel | 0.42.0 |
| widgetsnbextension | 4.0.10 |
| xarray | 2024.2.0 |
| yarl | 1.9.4 |
| zarr | 2.17.1 |
| zipp | 3.18.1 |