BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0
324 stars 58 forks

cell2location.models.RegressionModel.load takes one hour with GPU and seems to be retraining #185

Closed: bednarsky closed this issue 2 years ago

bednarsky commented 2 years ago

I am running cell2location as described in the tutorial. Reloading the model to continue the analysis takes about an hour, and because it steps through epochs and shows the training error, it looks to me as if the model is being retrained (see screenshot):

[Screenshot, 2022-07-29 13:27:41]

This was unexpected, and I couldn't find any discussion of it online. Is this expected behaviour, or is the problem on my end? (I am happy to give more details if this is not normal behaviour.)

vitkl commented 2 years ago

Hi @adamgayoso

Do you have any idea what is going on? This looks like a scvi-tools issue, or an issue in the interaction between scvi-tools and cell2location.

adamgayoso commented 2 years ago

I think what is happening is that we have to run 1 training step to initialize all the latent pyro params and THEN we reload the saved model params.

We should capture all that info so it doesn't confuse people.
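A minimal sketch of the load order described above, using a toy class rather than scvi-tools or Pyro code. `LazyModel`, `step`, `load_state` and `load` are hypothetical names chosen for illustration; the only point is that parameters created lazily on the first training step must exist before saved values can be written into them.

```python
class LazyModel:
    """Toy model whose parameters do not exist until the first training step."""

    def __init__(self):
        self.params = {}  # empty until step() has run once

    def step(self, batch):
        # Parameters are created lazily, sized from the first batch seen.
        if not self.params:
            self.params = {"loc": [0.0] * len(batch), "scale": [1.0] * len(batch)}
        # Placeholder update; a real optimizer step would go here.
        self.params["loc"] = [v + 0.1 for v in self.params["loc"]]

    def load_state(self, saved):
        # Loading only overwrites parameters that already exist.
        missing = set(saved) - set(self.params)
        if missing:
            raise KeyError(f"uninitialised parameters: {sorted(missing)}")
        self.params = {k: list(v) for k, v in saved.items()}


def load(saved, example_batch):
    model = LazyModel()
    model.step(example_batch)  # one auxiliary step creates the parameters
    model.load_state(saved)    # saved values then replace the fresh ones
    return model


saved = {"loc": [2.5, -1.0], "scale": [0.3, 0.7]}
restored = load(saved, example_batch=[0, 0])
print(restored.params)
```

In this toy version the auxiliary step changes `loc`, but the change is discarded as soon as `load_state` runs, which is why the extra step does not alter the restored model.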

adamgayoso commented 2 years ago

It also shouldn't take an hour -- @bednarsky what versions are you using?

vitkl commented 2 years ago

> We should capture all that info so it doesn't confuse people.

Do you mean to hide the messages associated with one auxiliary training step?

bednarsky commented 2 years ago

I also have the same issue with the second model in the tutorial, except that training and loading take much longer there.

An additional issue is that I don't have enough memory to train with the full batch size, and loading fails with the same memory error.

I am working with a GPU on Azure (NVIDIA Corporation GK210GL [Tesla K80] (rev a1)), CUDA Version: 11.4

These are my package versions (packages in environment at /home/azureuser/miniconda3/envs/cell2loc_azure):

Name                      Version       Build             Channel
_libgcc_mutex             0.1           main
_openmp_mutex             5.1           1_gnu
absl-py                   1.2.0         pypi_0            pypi
aiohttp                   3.8.1         pypi_0            pypi
aiosignal                 1.2.0         pypi_0            pypi
anndata                   0.8.0         pypi_0            pypi
argon2-cffi               21.3.0        pypi_0            pypi
argon2-cffi-bindings      21.2.0        pypi_0            pypi
asttokens                 2.0.5         pypi_0            pypi
async-timeout             4.0.2         pypi_0            pypi
attrs                     22.1.0        pypi_0            pypi
backcall                  0.2.0         pypi_0            pypi
beautifulsoup4            4.11.1        pypi_0            pypi
bleach                    5.0.1         pypi_0            pypi
ca-certificates           2022.07.19    h06a4308_0
cachetools                5.2.0         pypi_0            pypi
cell2location             0.1           pypi_0            pypi
certifi                   2022.6.15     py39h06a4308_0
cffi                      1.15.1        pypi_0            pypi
charset-normalizer        2.1.0         pypi_0            pypi
chex                      0.1.3         pypi_0            pypi
colorama                  0.4.5         pypi_0            pypi
commonmark                0.9.1         pypi_0            pypi
cycler                    0.11.0        pypi_0            pypi
debugpy                   1.6.2         pypi_0            pypi
decorator                 5.1.1         pypi_0            pypi
defusedxml                0.7.1         pypi_0            pypi
dm-tree                   0.1.7         pypi_0            pypi
docrep                    0.3.2         pypi_0            pypi
entrypoints               0.4           pypi_0            pypi
et-xmlfile                1.1.0         pypi_0            pypi
etils                     0.6.0         pypi_0            pypi
executing                 0.9.1         pypi_0            pypi
fastjsonschema            2.16.1        pypi_0            pypi
flax                      0.5.3         pypi_0            pypi
fonttools                 4.34.4        pypi_0            pypi
frozenlist                1.3.0         pypi_0            pypi
fsspec                    2022.7.0      pypi_0            pypi
google-auth               2.9.1         pypi_0            pypi
google-auth-oauthlib      0.4.6         pypi_0            pypi
grpcio                    1.47.0        pypi_0            pypi
h5py                      3.7.0         pypi_0            pypi
idna                      3.3           pypi_0            pypi
igraph                    0.9.11        pypi_0            pypi
importlib-metadata        4.12.0        pypi_0            pypi
importlib-resources       5.9.0         pypi_0            pypi
ipykernel                 6.15.1        pypi_0            pypi
ipython                   8.4.0         pypi_0            pypi
ipython-genutils          0.2.0         pypi_0            pypi
ipywidgets                7.7.1         pypi_0            pypi
jax                       0.3.15        pypi_0            pypi
jaxlib                    0.3.15        pypi_0            pypi
jedi                      0.18.1        pypi_0            pypi
jinja2                    3.1.2         pypi_0            pypi
joblib                    1.1.0         pypi_0            pypi
jsonschema                4.8.0         pypi_0            pypi
jupyter-client            7.3.4         pypi_0            pypi
jupyter-core              4.11.1        pypi_0            pypi
jupyterlab-pygments       0.2.2         pypi_0            pypi
jupyterlab-widgets        1.1.1         pypi_0            pypi
kiwisolver                1.4.4         pypi_0            pypi
ld_impl_linux-64          2.38          h1181459_1
leidenalg                 0.8.10        pypi_0            pypi
libffi                    3.3           he6710b0_2
libgcc-ng                 11.2.0        h1234567_1
libgomp                   11.2.0        h1234567_1
libstdcxx-ng              11.2.0        h1234567_1
llvmlite                  0.39.0        pypi_0            pypi
markdown                  3.4.1         pypi_0            pypi
markupsafe                2.1.1         pypi_0            pypi
matplotlib                3.5.2         pypi_0            pypi
matplotlib-inline         0.1.3         pypi_0            pypi
mistune                   0.8.4         pypi_0            pypi
msgpack                   1.0.4         pypi_0            pypi
mudata                    0.2.0         pypi_0            pypi
multidict                 6.0.2         pypi_0            pypi
multipledispatch          0.6.0         pypi_0            pypi
natsort                   8.1.0         pypi_0            pypi
nbclient                  0.6.6         pypi_0            pypi
nbconvert                 6.5.0         pypi_0            pypi
nbformat                  5.4.0         pypi_0            pypi
ncurses                   6.3           h5eee18b_3
nest-asyncio              1.5.5         pypi_0            pypi
networkx                  2.8.5         pypi_0            pypi
notebook                  6.4.12        pypi_0            pypi
numba                     0.56.0        pypi_0            pypi
numpy                     1.22.4        pypi_0            pypi
numpyro                   0.10.0        pypi_0            pypi
oauthlib                  3.2.0         pypi_0            pypi
opencv-python             4.6.0.66      pypi_0            pypi
openpyxl                  3.0.10        pypi_0            pypi
openssl                   1.1.1q        h7f8727e_0
opt-einsum                3.3.0         pypi_0            pypi
optax                     0.1.3         pypi_0            pypi
packaging                 21.3          pypi_0            pypi
pandas                    1.4.3         pypi_0            pypi
pandocfilters             1.5.0         pypi_0            pypi
parso                     0.8.3         pypi_0            pypi
patsy                     0.5.2         pypi_0            pypi
pexpect                   4.8.0         pypi_0            pypi
pickleshare               0.7.5         pypi_0            pypi
pillow                    9.2.0         pypi_0            pypi
pip                       22.1.2        py39h06a4308_0
prometheus-client         0.14.1        pypi_0            pypi
prompt-toolkit            3.0.30        pypi_0            pypi
protobuf                  3.20.1        pypi_0            pypi
psutil                    5.9.1         pypi_0            pypi
ptyprocess                0.7.0         pypi_0            pypi
pure-eval                 0.2.2         pypi_0            pypi
pyasn1                    0.4.8         pypi_0            pypi
pyasn1-modules            0.2.8         pypi_0            pypi
pycparser                 2.21          pypi_0            pypi
pydeprecate               0.3.2         pypi_0            pypi
pygments                  2.12.0        pypi_0            pypi
pynndescent               0.5.7         pypi_0            pypi
pyparsing                 3.0.9         pypi_0            pypi
pyro-api                  0.1.2         pypi_0            pypi
pyro-ppl                  1.8.1         pypi_0            pypi
pyrsistent                0.18.1        pypi_0            pypi
python                    3.9.12        h12debd9_1
python-dateutil           2.8.2         pypi_0            pypi
python-igraph             0.9.11        pypi_0            pypi
pytorch-lightning         1.6.5         pypi_0            pypi
pytz                      2022.1        pypi_0            pypi
pyyaml                    6.0           pypi_0            pypi
pyzmq                     23.2.0        pypi_0            pypi
readline                  8.1.2         h7f8727e_1
requests                  2.28.1        pypi_0            pypi
requests-oauthlib         1.3.1         pypi_0            pypi
rich                      11.2.0        pypi_0            pypi
rsa                       4.9           pypi_0            pypi
scanpy                    1.9.1         pypi_0            pypi
scikit-learn              1.1.1         pypi_0            pypi
scipy                     1.8.1         pypi_0            pypi
scvi-tools                0.17.1        pypi_0            pypi
seaborn                   0.11.2        pypi_0            pypi
send2trash                1.8.0         pypi_0            pypi
session-info              1.0.0         pypi_0            pypi
setuptools                61.2.0        py39h06a4308_0
six                       1.16.0        pypi_0            pypi
soupsieve                 2.3.2.post1   pypi_0            pypi
sqlite                    3.38.5        hc218d9a_0
stack-data                0.3.0         pypi_0            pypi
statsmodels               0.13.2        pypi_0            pypi
stdlib-list               0.8.0         pypi_0            pypi
tensorboard               2.9.0         pypi_0            pypi
tensorboard-data-server   0.6.1         pypi_0            pypi
tensorboard-plugin-wit    1.8.1         pypi_0            pypi
tensorstore               0.1.21        pypi_0            pypi
terminado                 0.15.0        pypi_0            pypi
texttable                 1.6.4         pypi_0            pypi
threadpoolctl             3.1.0         pypi_0            pypi
tinycss2                  1.1.1         pypi_0            pypi
tk                        8.6.12        h1ccaba5_0
toolz                     0.12.0        pypi_0            pypi
torch                     1.12.0        pypi_0            pypi
torchmetrics              0.9.3         pypi_0            pypi
tornado                   6.2           pypi_0            pypi
tqdm                      4.64.0        pypi_0            pypi
traitlets                 5.3.0         pypi_0            pypi
typing-extensions         4.3.0         pypi_0            pypi
tzdata                    2022a         hda174b7_0
umap-learn                0.5.3         pypi_0            pypi
urllib3                   1.26.11       pypi_0            pypi
wcwidth                   0.2.5         pypi_0            pypi
webencodings              0.5.1         pypi_0            pypi
werkzeug                  2.2.1         pypi_0            pypi
wheel                     0.37.1        pyhd3eb1b0_0
widgetsnbextension        3.6.1         pypi_0            pypi
xz                        5.2.5         h7f8727e_1
yarl                      1.7.2         pypi_0            pypi
zipp                      3.8.1         pypi_0            pypi
zlib                      1.2.12        h7f8727e_2

I installed the environment as recommended, using these commands:

    export PYTHONNOUSERSITE="literallyanyletters"
    conda create -y -n cell2loc_azure python=3.9
    conda activate cell2loc_azure
    pip install git+https://github.com/BayraktarLab/cell2location.git#egg=cell2location[tutorials]
    python -m ipykernel install --user --name=cell2loc_azure --display-name='cell2loc_azure'

Thank you for your help!

livyring commented 2 years ago

This is also happening to me! Is there any update on this?

adamgayoso commented 2 years ago

Is it only happening for the regression model?

adamgayoso commented 2 years ago

@vitkl @jjhong922 the issue is that

mod.train(max_steps=1, max_epochs=MAX_EPOCHS)

doesn't respect max_steps for the regression model and always uses max epochs, which I don't understand. The loading should only do one opt step to reinitialize the params.
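To illustrate the difference, here is a toy training loop, not the scvi-tools or PyTorch Lightning implementation, that stops at whichever limit is reached first. The function name `train` and its `n_batches` argument are assumptions made for the sketch.

```python
def train(n_batches, max_epochs, max_steps=None):
    """Toy loop: return the number of optimizer steps actually taken."""
    steps = 0
    for _epoch in range(max_epochs):
        for _batch in range(n_batches):
            steps += 1  # one optimizer step per batch
            if max_steps is not None and steps >= max_steps:
                return steps  # step limit reached before the epoch limit
    return steps


# Full-batch training means one batch, and so one step, per epoch.
print(train(n_batches=1, max_epochs=250, max_steps=1))     # 1
print(train(n_batches=1, max_epochs=250, max_steps=None))  # 250
```

The second call mimics the reported behaviour: when the step limit is not honoured, the loop runs for every epoch even though a single step would have been enough to initialize the parameters.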

livyring commented 2 years ago

Yes only for the regression model

adamgayoso commented 2 years ago

Fix is here: https://github.com/scverse/scvi-tools/pull/1636

As a workaround for now, send a keyboard interrupt (Ctrl+C) as soon as it looks like it's training. It will then stop and load properly. There is no reason to worry that your old params won't be restored, because the restore happens after this "training" part. The training part is only needed to reinitialize all the pyro params.
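The order of operations behind this workaround can be sketched with a toy loader. This is an illustration only: `load_model`, `fake_init_train` and `fake_restore` are hypothetical names, and where exactly the interrupt is caught inside scvi-tools or its trainer is not stated in this thread.

```python
def load_model(init_train, restore):
    """Run the initialisation phase, then restore the saved parameters."""
    try:
        init_train()
    except KeyboardInterrupt:
        # Ctrl+C ends the auxiliary training early; loading still continues.
        pass
    return restore()


state = {"initialised": False, "params": None}


def fake_init_train():
    state["initialised"] = True  # parameters exist after the first step
    raise KeyboardInterrupt      # stands in for the user pressing Ctrl+C


def fake_restore():
    if not state["initialised"]:
        raise RuntimeError("parameters were never initialised")
    state["params"] = "saved values"
    return state["params"]


result = load_model(fake_init_train, fake_restore)
print(result)
```

Because the restore step runs after the interrupted phase, whatever the interrupted training did to the freshly created parameters is overwritten by the saved values.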

bednarsky commented 2 years ago

> Is it only happening for the regression model?

For me it happens for both the regression model and the mapping model! Thanks for all the answers!

vitkl commented 2 years ago

The latest scvi-tools contains this fix now (probably just the GitHub version for now).

adamgayoso commented 2 years ago

This is fixed now in scvi-tools 0.17.2, @vitkl you should consider lower bounding your requirement.
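For readers who want to check whether their environment already has the fix, here is a small sketch using only the standard library. The helper names `release_tuple` and `scvi_has_fix` are hypothetical, and the parser is a simplification that ignores pre-release and local suffixes. On the packaging side, a lower bound is normally written as the requirement specifier `scvi-tools>=0.17.2`; where cell2location declares its requirements is not shown in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_SCVI = (0, 17, 2)  # first scvi-tools release reported above to contain the fix


def release_tuple(v):
    """Parse leading numeric components, e.g. '0.17.2' -> (0, 17, 2)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def scvi_has_fix(installed):
    # Simplification: suffixes such as 'rc1' or '+local' are ignored.
    return release_tuple(installed) >= MIN_SCVI


try:
    installed = version("scvi-tools")
    print(installed, "has fix:", scvi_has_fix(installed))
except PackageNotFoundError:
    print("scvi-tools is not installed")
```

With the environment listed earlier in this thread (scvi-tools 0.17.1), `scvi_has_fix` returns `False`, so upgrading would be needed.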