BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

GPU error in mod.train #269

Open renyuan1988 opened 1 year ago

renyuan1988 commented 1 year ago

Hi, thanks very much for developing such a brilliant package. I ran into an error in `mod.train` and hope you can give me some clues. I am using our institution's HPC cluster, which has a GPU node (48 cores, 4 × NVIDIA A100 40 GB). I ran my job in a Jupyter notebook, and when I reached the `mod.train` step in the section "Estimation of reference cell type signatures (NB regression)" and ran `mod.train(max_epochs=250, use_gpu=True)`, it returned the error:

"RuntimeError: Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: mps"

I don't know which part went wrong and would appreciate your guidance. Many thanks for your time. Below are my code, output, and session info:

```python
mod.train(max_epochs=250, use_gpu=True)

RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 mod.train(max_epochs=250, use_gpu=True)

File ~/anaconda3/envs/test_scvi16_cuda113/lib/python3.9/site-packages/cell2location/models/reference/_reference_model.py:157, in RegressionModel.train(self, max_epochs, batch_size, train_size, lr, **kwargs)
    154 kwargs["train_size"] = train_size
    155 kwargs["lr"] = lr
--> 157 super().train(**kwargs)

File ~/anaconda3/envs/test_scvi16_cuda113/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:164, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    161 trainer_kwargs["callbacks"] = []
    162 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
--> 164 runner = self._train_runner_cls(
    165     self,
    166     training_plan=training_plan,
    167     data_splitter=data_splitter,
    168     max_epochs=max_epochs,
    169     use_gpu=use_gpu,
    170     **trainer_kwargs,
    171 )
    172 return runner()

File ~/anaconda3/envs/test_scvi16_cuda113/lib/python3.9/site-packages/scvi/train/_trainrunner.py:64, in TrainRunner.__init__(self, model, training_plan, data_splitter, max_epochs, use_gpu, **trainer_kwargs)
     62 self.data_splitter = data_splitter
     63 self.model = model
---> 64 accelerator, lightning_devices, device = parse_use_gpu_arg(use_gpu)
     65 self.accelerator = accelerator
     66 self.lightning_devices = lightning_devices

File ~/anaconda3/envs/test_scvi16_cuda113/lib/python3.9/site-packages/scvi/model/_utils.py:59, in parse_use_gpu_arg(use_gpu, return_device)
     57 accelerator = "mps"
     58 lightning_devices = 1
---> 59 device = torch.device(current)
     60 # Also captures bool case
     61 elif isinstance(use_gpu, int):

RuntimeError: Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: mps
```

```python
cell2location.utils.list_imported_modules()

sys 3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0]
re 2.2.1
ipykernel._version 6.21.3
json 2.0.9
jupyter_client._version 8.0.3
platform 1.0.8
_ctypes 1.1.0
ctypes 1.1.0
zmq.sugar.version 25.0.1
zmq.sugar 25.0.1
zmq 25.0.1
logging 0.5.1.2
traitlets._version 5.9.0
traitlets 5.9.0
jupyter_core.version 5.3.0
jupyter_core 5.3.0
tornado 6.2
zlib 1.0
_curses b'2.2'
socketserver 0.4
argparse 1.1
dateutil._version 2.8.2
dateutil 2.8.2
six 1.16.0
_decimal 1.70
decimal 1.70
platformdirs.version 3.1.1
platformdirs 3.1.1
_csv 1.0
csv 1.0
jupyter_client 8.0.3
ipykernel 6.21.3
IPython.core.release 8.11.0
executing.version 1.2.0
executing 1.2.0
pure_eval.version 0.2.2
pure_eval 0.2.2
stack_data.version 0.6.2
stack_data 0.6.2
pygments 2.14.0
ptyprocess 0.7.0
pexpect 4.8.0
IPython.core.crashhandler 8.11.0
pickleshare 0.7.5
backcall 0.2.0
decorator 5.1.1
_sqlite3 2.6.0
sqlite3.dbapi2 2.6.0
sqlite3 2.6.0
wcwidth 0.2.6
prompt_toolkit 3.0.38
parso 0.8.3
jedi 0.18.2
urllib.request 3.9
IPython.core.magics.code 8.11.0
IPython 8.11.0
comm 0.1.2
psutil 5.9.4
debugpy.public_api 1.6.6
debugpy 1.6.6
xmlrpc.client 3.9
http.server 0.6
pkg_resources._vendor.more_itertools 8.12.0
pkg_resources.extern.more_itertools 8.12.0
pkg_resources._vendor.appdirs 1.4.3
pkg_resources.extern.appdirs 1.4.3
pkg_resources._vendor.packaging.__about__ 21.3
pkg_resources._vendor.packaging 21.3
pkg_resources.extern.packaging 21.3
pkg_resources._vendor.pyparsing 3.0.9
pkg_resources.extern.pyparsing 3.0.9
_pydevd_frame_eval.vendored.bytecode 0.13.0.dev
_pydev_bundle.fsnotify 0.1.5
pydevd 2.9.5
packaging.__about__ 22.0
packaging 22.0
scanpy._metadata 1.9.3
mkl 2.4.0
numpy.version 1.23.5
numpy.core._multiarray_umath 3.1
numpy.core 1.23.5
numpy.linalg._umath_linalg 0.1.5
numpy.lib 1.23.5
numpy 1.23.5
scipy.version 1.10.1
scipy 1.10.1
scipy.sparse.linalg._isolve._iterative b'$Revision: $'
scipy._lib.decorator 4.0.5
scipy.linalg._fblas b'$Revision: $'
scipy.linalg._flapack b'$Revision: $'
scipy.linalg._flinalg b'$Revision: $'
scipy.sparse.linalg._eigen.arpack._arpack b'$Revision: $'
anndata._metadata 0.8.0
h5py.version 3.8.0
h5py 3.8.0
natsort 8.3.1
pytz 2022.7.1
numexpr.version 2.8.4
numexpr 2.8.4
tarfile 0.9.0
pandas 1.5.3
anndata 0.8.0
yaml 6.0
llvmlite 0.39.1
numba.cloudpickle 1.6.0
numba.misc.appdirs 1.4.1
numba 0.56.4
setuptools._distutils 3.9.16
setuptools.version 65.6.3
setuptools._vendor.packaging.__about__ 21.3
setuptools._vendor.packaging 21.3
setuptools.extern.packaging 21.3
setuptools._vendor.ordered_set 3.1
setuptools.extern.ordered_set 3.1
setuptools._vendor.more_itertools 8.8.0
setuptools.extern.more_itertools 8.8.0
setuptools._vendor.pyparsing 3.0.9
setuptools.extern.pyparsing 3.0.9
setuptools 65.6.3
distutils 3.9.16
joblib.externals.cloudpickle 2.2.0
joblib.externals.loky 3.3.0
joblib 1.2.0
sklearn.utils._joblib 1.2.0
scipy.special._specfun b'$Revision: $'
scipy.optimize._minpack2 b'$Revision: $'
scipy.optimize._lbfgsb b'$Revision: $'
scipy.optimize._cobyla b'$Revision: $'
scipy.optimize._slsqp b'$Revision: $'
scipy.optimize.__nnls b'$Revision: $'
scipy.linalg._interpolative b'$Revision: $'
scipy.integrate._vode b'$Revision: $'
scipy.integrate._dop b'$Revision: $'
scipy.integrate._lsoda b'$Revision: $'
scipy.interpolate.dfitpack b'$Revision: $'
scipy._lib._uarray 0.8.8.dev0+aa94c5a4.scipy
scipy.stats._statlib b'$Revision: $'
scipy.stats._mvn b'$Revision: $'
threadpoolctl 3.1.0
sklearn.base 1.2.2
sklearn.utils._show_versions 1.2.2
sklearn 1.2.2
matplotlib._version 3.7.1
PIL._version 9.4.0
PIL 9.4.0
PIL._deprecate 9.4.0
PIL.Image 9.4.0
pyparsing 3.0.9
cycler 0.10.0
kiwisolver._cext 1.4.4
kiwisolver 1.4.4
matplotlib 3.7.1
texttable 1.6.7
igraph.version 0.10.4
igraph 0.10.4
leidenalg.version 0.9.1
leidenalg 0.9.1
scanpy 1.9.3
torch.version 1.11.0+cu113
torch.torch_version 1.11.0+cu113
torch.cuda.nccl (2, 10, 3)
torch.backends.cudnn 8200
tqdm._dist_ver 4.65.0
tqdm.version 4.65.0
tqdm.cli 4.65.0
tqdm 4.65.0
torch 1.11.0+cu113
opt_einsum v3.3.0
pyro._version 1.8.4
pyro 1.8.4
attr 22.2.0
pytorch_lightning.__version__ 1.9.4
lightning_fabric.__version__ 1.9.4
lightning_utilities.__about__ 0.8.0
lightning_utilities 0.8.0
fsspec 2023.3.0
lightning_fabric 1.9.4
torchmetrics.__about__ 0.11.4
urllib3.packages.six 1.16.0
urllib3._version 1.26.15
ipaddress 1.0
urllib3.util.ssl_match_hostname 3.5.0.1
urllib3.connection 1.26.15
urllib3 1.26.15
charset_normalizer.version 3.1.0
charset_normalizer 3.1.0
requests.packages.urllib3.packages.six 1.16.0
requests.packages.urllib3._version 1.26.15
requests.packages.urllib3.util.ssl_match_hostname 3.5.0.1
requests.packages.urllib3.connection 1.26.15
requests.packages.urllib3 1.26.15
idna.package_data 3.4
idna.idnadata 15.0.0
idna 3.4
requests.packages.idna.package_data 3.4
requests.packages.idna.idnadata 15.0.0
requests.packages.idna 3.4
certifi 2022.12.07
requests.__version__ 2.28.2
requests.utils 2.28.2
requests 2.28.2
xml.etree.ElementTree 1.3.0
torchvision.version 0.12.0+cu113
torchvision 0.12.0+cu113
torchmetrics 0.11.4
pytorch_lightning 1.9.4
jaxlib.version 0.4.6
jaxlib 0.4.6
jax.version 0.4.6
etils 1.1.0
jax.lib 0.4.6
jax 0.4.6
tree 0.1.8
xml.sax.handler 2.0beta
toolz 0.12.0
chex 0.1.6
mudata 0.2.1
docrep 0.3.2
msgpack 1.0.5
flax.version 0.6.7
flax 0.6.7
optax 0.1.4
multipledispatch 0.6.0
numpyro.version 0.11.0
numpyro 0.11.0
scvi 0.20.2
pynndescent 0.5.8
umap 0.5.3
cell2location 0.1.3
```

vitkl commented 1 year ago

This could be a package version mismatch. I would try creating a conda environment from scratch, following the installation instructions: https://github.com/BayraktarLab/cell2location#installation.

However, I see that the torch, scvi, cell2location, and pytorch-lightning versions look OK.
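For anyone hitting the same error, a minimal sketch of how the version check could be repeated from inside the notebook kernel is shown here. The helper names `installed_versions` and `cuda_status` and the `PACKAGES` list are hypothetical and not part of cell2location or scvi-tools. The traceback above stops in the `"mps"` branch of `parse_use_gpu_arg`, which may indicate that torch did not see a CUDA device in that environment, but that is only a guess, so the sketch also reports what torch itself says about CUDA.

```python
from importlib import metadata

# Distribution names as published on PyPI; assumed to match the environment.
PACKAGES = ["torch", "scvi-tools", "cell2location", "pytorch-lightning", "pyro-ppl"]


def installed_versions(names):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_status():
    """Report what torch sees; returns None if torch cannot be imported."""
    try:
        import torch
    except ImportError:
        return None
    return {
        "cuda_available": torch.cuda.is_available(),  # False means no usable GPU
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "device_count": torch.cuda.device_count(),
    }


print(installed_versions(PACKAGES))
print(cuda_status())
```

If `cuda_available` is `False` on a GPU node, the problem is more likely in the environment (driver, CUDA build of torch, or job allocation) than in the `mod.train` call itself.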

This could also be an issue with a conda environment that is not working correctly. Try running `export PYTHONNOUSERSITE="literallyanyletters"` on the command line before activating conda environments and before installing packages.
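A small sketch for checking whether that setting took effect in the running kernel follows. It uses only the standard library, and `user_site_report` is a hypothetical helper name. `PYTHONNOUSERSITE` is read when Python starts, so it has to be exported before the Jupyter server or kernel is launched; any non-empty value works.

```python
import os
import site
import sys


def user_site_report():
    """Summarize whether the per-user site-packages directory is in use."""
    user_site = site.getusersitepackages()
    return {
        # Value seen by this process; None means the variable was not set.
        "PYTHONNOUSERSITE": os.environ.get("PYTHONNOUSERSITE"),
        # True if Python enabled the user site directory at startup.
        "user_site_enabled": site.ENABLE_USER_SITE,
        "user_site_dir": user_site,
        # True if packages from ~/.local (or equivalent) can shadow the conda env.
        "user_site_on_path": user_site in sys.path,
    }


for key, value in user_site_report().items():
    print(f"{key}: {value}")
```

If `user_site_on_path` is `True`, packages installed with `pip install --user` can take precedence over the ones in the conda environment, which is one way a version mismatch can appear even when the environment itself looks correct.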

renyuan1988 commented 1 year ago

Hi, thanks so much for your quick reply. I'm 100% sure that I ran `export PYTHONNOUSERSITE="literallyanyletters"` on the command line before anything else, so I decided to reinstall the conda environment. Unfortunately, I can't open the link you provided. Could you give me the installation commands in detail? Thanks!

renyuan1988 commented 1 year ago

Hi, I have rebuilt the conda environment from scratch and the GPU problem is solved; everything went fine in the single-slide test. I'm now going to test multiple slides, but I'm a little confused by the code you provide for combining slides. I don't know the correct format of the annotation CSV file. Could you give me some detailed instructions on this?