BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

Passing axis to mod.plot_history raises AttributeError #190

Closed: bednarsky closed this issue 2 years ago

bednarsky commented 2 years ago

When training the signature estimation model in a script, I tried to save the training history plot produced by mod.plot_history(). You have added an option to pass an axis to this function via the ax parameter, so I tried it like this:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
mod.plot_history(20, ax=ax)  # raises AttributeError
fig.savefig(QC_dir / 'training_history_signature.png', dpi=300)  # QC_dir is presumably a pathlib.Path

FYI, saving still works if you don't pass an axis. For filter_genes (https://cell2location.readthedocs.io/en/latest/_modules/cell2location/utils/filtering.html#filter_genes), however, saving this way does not work.

I am working on an Azure machine with a GPU (NVIDIA Corporation GK210GL [Tesla K80] (rev a1)) and CUDA version 11.4.

These are my package versions (conda list output for the environment at /home/azureuser/miniconda3/envs/cell2loc_azure):

Name                     Version      Build           Channel
_libgcc_mutex            0.1          main
_openmp_mutex            5.1          1_gnu
absl-py                  1.2.0        pypi_0          pypi
aiohttp                  3.8.1        pypi_0          pypi
aiosignal                1.2.0        pypi_0          pypi
anndata                  0.8.0        pypi_0          pypi
argon2-cffi              21.3.0       pypi_0          pypi
argon2-cffi-bindings     21.2.0       pypi_0          pypi
asttokens                2.0.5        pypi_0          pypi
async-timeout            4.0.2        pypi_0          pypi
attrs                    22.1.0       pypi_0          pypi
backcall                 0.2.0        pypi_0          pypi
beautifulsoup4           4.11.1       pypi_0          pypi
bleach                   5.0.1        pypi_0          pypi
ca-certificates          2022.07.19   h06a4308_0
cachetools               5.2.0        pypi_0          pypi
cell2location            0.1          pypi_0          pypi
certifi                  2022.6.15    py39h06a4308_0
cffi                     1.15.1       pypi_0          pypi
charset-normalizer       2.1.0        pypi_0          pypi
chex                     0.1.3        pypi_0          pypi
colorama                 0.4.5        pypi_0          pypi
commonmark               0.9.1        pypi_0          pypi
cycler                   0.11.0       pypi_0          pypi
debugpy                  1.6.2        pypi_0          pypi
decorator                5.1.1        pypi_0          pypi
defusedxml               0.7.1        pypi_0          pypi
dm-tree                  0.1.7        pypi_0          pypi
docrep                   0.3.2        pypi_0          pypi
entrypoints              0.4          pypi_0          pypi
et-xmlfile               1.1.0        pypi_0          pypi
etils                    0.6.0        pypi_0          pypi
executing                0.9.1        pypi_0          pypi
fastjsonschema           2.16.1       pypi_0          pypi
flax                     0.5.3        pypi_0          pypi
fonttools                4.34.4       pypi_0          pypi
frozenlist               1.3.0        pypi_0          pypi
fsspec                   2022.7.0     pypi_0          pypi
google-auth              2.9.1        pypi_0          pypi
google-auth-oauthlib     0.4.6        pypi_0          pypi
grpcio                   1.47.0       pypi_0          pypi
h5py                     3.7.0        pypi_0          pypi
idna                     3.3          pypi_0          pypi
igraph                   0.9.11       pypi_0          pypi
importlib-metadata       4.12.0       pypi_0          pypi
importlib-resources      5.9.0        pypi_0          pypi
ipykernel                6.15.1       pypi_0          pypi
ipython                  8.4.0        pypi_0          pypi
ipython-genutils         0.2.0        pypi_0          pypi
ipywidgets               7.7.1        pypi_0          pypi
jax                      0.3.15       pypi_0          pypi
jaxlib                   0.3.15       pypi_0          pypi
jedi                     0.18.1       pypi_0          pypi
jinja2                   3.1.2        pypi_0          pypi
joblib                   1.1.0        pypi_0          pypi
jsonschema               4.8.0        pypi_0          pypi
jupyter-client           7.3.4        pypi_0          pypi
jupyter-core             4.11.1       pypi_0          pypi
jupyterlab-pygments      0.2.2        pypi_0          pypi
jupyterlab-widgets       1.1.1        pypi_0          pypi
kiwisolver               1.4.4        pypi_0          pypi
ld_impl_linux-64         2.38         h1181459_1
leidenalg                0.8.10       pypi_0          pypi
libffi                   3.3          he6710b0_2
libgcc-ng                11.2.0       h1234567_1
libgomp                  11.2.0       h1234567_1
libstdcxx-ng             11.2.0       h1234567_1
llvmlite                 0.39.0       pypi_0          pypi
markdown                 3.4.1        pypi_0          pypi
markupsafe               2.1.1        pypi_0          pypi
matplotlib               3.5.2        pypi_0          pypi
matplotlib-inline        0.1.3        pypi_0          pypi
mistune                  0.8.4        pypi_0          pypi
msgpack                  1.0.4        pypi_0          pypi
mudata                   0.2.0        pypi_0          pypi
multidict                6.0.2        pypi_0          pypi
multipledispatch         0.6.0        pypi_0          pypi
natsort                  8.1.0        pypi_0          pypi
nbclient                 0.6.6        pypi_0          pypi
nbconvert                6.5.0        pypi_0          pypi
nbformat                 5.4.0        pypi_0          pypi
ncurses                  6.3          h5eee18b_3
nest-asyncio             1.5.5        pypi_0          pypi
networkx                 2.8.5        pypi_0          pypi
notebook                 6.4.12       pypi_0          pypi
numba                    0.56.0       pypi_0          pypi
numpy                    1.22.4       pypi_0          pypi
numpyro                  0.10.0       pypi_0          pypi
oauthlib                 3.2.0        pypi_0          pypi
opencv-python            4.6.0.66     pypi_0          pypi
openpyxl                 3.0.10       pypi_0          pypi
openssl                  1.1.1q       h7f8727e_0
opt-einsum               3.3.0        pypi_0          pypi
optax                    0.1.3        pypi_0          pypi
packaging                21.3         pypi_0          pypi
pandas                   1.4.3        pypi_0          pypi
pandocfilters            1.5.0        pypi_0          pypi
parso                    0.8.3        pypi_0          pypi
patsy                    0.5.2        pypi_0          pypi
pexpect                  4.8.0        pypi_0          pypi
pickleshare              0.7.5        pypi_0          pypi
pillow                   9.2.0        pypi_0          pypi
pip                      22.1.2       py39h06a4308_0
prometheus-client        0.14.1       pypi_0          pypi
prompt-toolkit           3.0.30       pypi_0          pypi
protobuf                 3.20.1       pypi_0          pypi
psutil                   5.9.1        pypi_0          pypi
ptyprocess               0.7.0        pypi_0          pypi
pure-eval                0.2.2        pypi_0          pypi
pyasn1                   0.4.8        pypi_0          pypi
pyasn1-modules           0.2.8        pypi_0          pypi
pycparser                2.21         pypi_0          pypi
pydeprecate              0.3.2        pypi_0          pypi
pygments                 2.12.0       pypi_0          pypi
pynndescent              0.5.7        pypi_0          pypi
pyparsing                3.0.9        pypi_0          pypi
pyro-api                 0.1.2        pypi_0          pypi
pyro-ppl                 1.8.1        pypi_0          pypi
pyrsistent               0.18.1       pypi_0          pypi
python                   3.9.12       h12debd9_1
python-dateutil          2.8.2        pypi_0          pypi
python-igraph            0.9.11       pypi_0          pypi
pytorch-lightning        1.6.5        pypi_0          pypi
pytz                     2022.1       pypi_0          pypi
pyyaml                   6.0          pypi_0          pypi
pyzmq                    23.2.0       pypi_0          pypi
readline                 8.1.2        h7f8727e_1
requests                 2.28.1       pypi_0          pypi
requests-oauthlib        1.3.1        pypi_0          pypi
rich                     11.2.0       pypi_0          pypi
rsa                      4.9          pypi_0          pypi
scanpy                   1.9.1        pypi_0          pypi
scikit-learn             1.1.1        pypi_0          pypi
scipy                    1.8.1        pypi_0          pypi
scvi-tools               0.17.1       pypi_0          pypi
seaborn                  0.11.2       pypi_0          pypi
send2trash               1.8.0        pypi_0          pypi
session-info             1.0.0        pypi_0          pypi
setuptools               61.2.0       py39h06a4308_0
six                      1.16.0       pypi_0          pypi
soupsieve                2.3.2.post1  pypi_0          pypi
sqlite                   3.38.5       hc218d9a_0
stack-data               0.3.0        pypi_0          pypi
statsmodels              0.13.2       pypi_0          pypi
stdlib-list              0.8.0        pypi_0          pypi
tensorboard              2.9.0        pypi_0          pypi
tensorboard-data-server  0.6.1        pypi_0          pypi
tensorboard-plugin-wit   1.8.1        pypi_0          pypi
tensorstore              0.1.21       pypi_0          pypi
terminado                0.15.0       pypi_0          pypi
texttable                1.6.4        pypi_0          pypi
threadpoolctl            3.1.0        pypi_0          pypi
tinycss2                 1.1.1        pypi_0          pypi
tk                       8.6.12       h1ccaba5_0
toolz                    0.12.0       pypi_0          pypi
torch                    1.12.0       pypi_0          pypi
torchmetrics             0.9.3        pypi_0          pypi
tornado                  6.2          pypi_0          pypi
tqdm                     4.64.0       pypi_0          pypi
traitlets                5.3.0        pypi_0          pypi
typing-extensions        4.3.0        pypi_0          pypi
tzdata                   2022a        hda174b7_0
umap-learn               0.5.3        pypi_0          pypi
urllib3                  1.26.11      pypi_0          pypi
wcwidth                  0.2.5        pypi_0          pypi
webencodings             0.5.1        pypi_0          pypi
werkzeug                 2.2.1        pypi_0          pypi
wheel                    0.37.1       pyhd3eb1b0_0
widgetsnbextension       3.6.1        pypi_0          pypi
xz                       5.2.5        h7f8727e_1
yarl                     1.7.2        pypi_0          pypi
zipp                     3.8.1        pypi_0          pypi
zlib                     1.2.12       h7f8727e_2

I have installed the environment as recommended; these were my commands:

export PYTHONNOUSERSITE="literallyanyletters"
conda create -y -n cell2loc_azure python=3.9
conda activate cell2loc_azure
pip install git+https://github.com/BayraktarLab/cell2location.git#egg=cell2location[tutorials]
python -m ipykernel install --user --name=cell2loc_azure --display-name='cell2loc_azure'
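When only a handful of packages matter for a plotting bug like this one, a shorter version summary can be produced from Python itself. The sketch below uses the standard-library importlib.metadata module (Python 3.8+). The package list is an illustrative choice and was not requested in the thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Illustrative selection of packages relevant to this report (an assumption).
packages = ["cell2location", "scvi-tools", "matplotlib", "torch", "pyro-ppl"]

installed = {}
for name in packages:
    try:
        # Look up the installed distribution's version string by its PyPI name.
        installed[name] = version(name)
    except PackageNotFoundError:
        installed[name] = "not installed"

for name, ver in installed.items():
    print(f"{name:15} {ver}")
```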

Thank you for your help!

vitkl commented 2 years ago

Hi @bednarsky

I think a workaround is simply:

import matplotlib.pyplot as plt

# plot ELBO loss history during training, omitting the first 5000 epochs from the plot
# (scvi_run_name is assumed to be the run's output directory, defined elsewhere)
mod.plot_history(5000)
plt.savefig(f"{scvi_run_name}/training_ELBO_history_minus5k.png",
            bbox_inches='tight')
plt.close()
# plot the full ELBO loss history
mod.plot_history(0)
plt.savefig(f"{scvi_run_name}/training_ELBO_history_all.png",
            bbox_inches='tight')
plt.close()
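This workaround works because plot_history, when called without ax, draws on pyplot's current Axes. In the fixed version posted later in this thread, it falls back to plt.gca(). As a result, plt.savefig saves that current figure, and plt.close releases it before the next plot. The sketch below reproduces the pattern with a hypothetical plot_loss helper, a synthetic loss curve, and an in-memory buffer in place of the output directory. None of these are part of cell2location.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend, as in a script on a remote VM
import matplotlib.pyplot as plt
import numpy as np


def plot_loss(loss, iter_start=0, ax=None):
    """Hypothetical stand-in for a plotting helper with an optional ax argument."""
    if ax is None:
        ax = plt.gca()  # fall back to the current Axes (created if none exists)
    epochs = np.arange(len(loss))
    ax.plot(epochs[iter_start:], loss[iter_start:], label="train")
    ax.set_xlabel("Training epochs")
    ax.set_ylabel("-ELBO loss")
    ax.legend()


loss = np.linspace(1000.0, 10.0, 200)  # synthetic loss curve for illustration

plot_loss(loss, iter_start=50)  # no ax passed: draws on the current figure
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")  # saves the current figure
plt.close()  # close it so the next call starts on a fresh figure

png_bytes = buf.getvalue()
```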

@yozhikoff do you have any idea why passing ax leads to an error?
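The thread does not include the traceback, so the exact failing call is not known here. One common source of this kind of AttributeError is calling pyplot-style functions such as xlim or xlabel on an Axes object. pyplot provides these as module-level functions that act on the current Axes, whereas an Axes object only has set_xlim, set_xlabel and so on. The check below illustrates the difference. It is a plausible explanation, not a confirmed diagnosis of this issue.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the example runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot's module-level helpers act on the current Axes...
plt.xlim(0, 10)
plt.xlabel("Training epochs")

# ...but an Axes object only has the set_* equivalents.
print(hasattr(ax, "xlim"), hasattr(ax, "xlabel"))  # False False
print(hasattr(ax, "set_xlim"), hasattr(ax, "set_xlabel"))  # True True

try:
    ax.xlabel("Training epochs")  # pyplot-style call on an Axes object
except AttributeError as err:
    print(f"AttributeError: {err}")

# The object-oriented calls work whether or not ax is the current Axes.
ax.set_xlim(0, 10)
ax.set_xlabel("Training epochs")
plt.close(fig)
```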

yozhikoff commented 2 years ago

Fixed in #194

Also, here's the code if you want to use it manually.

import matplotlib.pyplot as plt
import numpy as np


def plot_history(self, iter_start=0, iter_end=-1, ax=None):
    r"""Plot training history.

    Parameters
    ----------
    iter_start
        omit initial iterations from the plot
    iter_end
        omit last iterations from the plot
    ax
        matplotlib axis
    """
    if ax is None:
        ax = plt.gca()  # fall back to the current axis when none is passed
    if iter_end == -1:
        iter_end = len(self.history_["elbo_train"])

    ax.plot(
        np.array(self.history_["elbo_train"].index[iter_start:iter_end]),
        np.array(self.history_["elbo_train"].values.flatten())[iter_start:iter_end],
        label="train",
    )
    # use Axes methods so everything is drawn on the given ax
    ax.legend()
    ax.set_xlim(0, len(self.history_["elbo_train"]))
    ax.set_xlabel("Training epochs")
    ax.set_ylabel("-ELBO loss")
    # lay out the figure that owns ax, not whichever figure happens to be current
    ax.figure.tight_layout()
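To try the fix without training a model, the sketch below runs a condensed copy of the posted plot_history against a hypothetical stand-in object. The stand-in only provides history_, a dict holding an "elbo_train" pandas DataFrame indexed by epoch. That layout is an assumption about the structure the function reads, and the loss values are synthetic. In this copy, tight_layout is called on ax's own figure rather than through pyplot.

```python
import io
from types import SimpleNamespace

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_history(self, iter_start=0, iter_end=-1, ax=None):
    """Condensed copy of the plot_history posted in the thread."""
    if ax is None:
        ax = plt.gca()
    if iter_end == -1:
        iter_end = len(self.history_["elbo_train"])
    ax.plot(
        np.array(self.history_["elbo_train"].index[iter_start:iter_end]),
        np.array(self.history_["elbo_train"].values.flatten())[iter_start:iter_end],
        label="train",
    )
    ax.legend()
    ax.set_xlim(0, len(self.history_["elbo_train"]))
    ax.set_xlabel("Training epochs")
    ax.set_ylabel("-ELBO loss")
    ax.figure.tight_layout()  # lay out the figure that owns ax


# Hypothetical stand-in for a trained model: only history_ is needed, and its
# layout (a DataFrame indexed by epoch) is an assumption about the real object.
n_epochs = 100
history = pd.DataFrame(
    {"elbo_train": np.linspace(5e6, 1e5, n_epochs)},
    index=pd.RangeIndex(n_epochs, name="epoch"),
)
mod = SimpleNamespace(history_={"elbo_train": history})

# Same pattern as in the original report: pass an explicit ax, save via fig.
fig, ax = plt.subplots()
plot_history(mod, 20, ax=ax)  # skip the first 20 epochs
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=300)
plt.close(fig)
```

On a trained model, the equivalent call is mod.plot_history(20, ax=ax), as in the original report.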