theislab / ehrapy

Electronic Health Record Analysis with Python.
https://ehrapy.readthedocs.io/
Apache License 2.0
232 stars 19 forks

Impossible to calculate UMAP with custom connectivities keys #789

Closed VladimirShitov closed 2 months ago

VladimirShitov commented 2 months ago

Report

Consider a use case in which the data contain several layers, each holding a different representation of the patients. You want to compute neighbours and connectivities separately for each layer and then build a UMAP from each. Naturally, you want to store the neighbours and connectivities in custom slots so that they are not lost. The key_added argument of ep.pp.neighbors and the neighbors_key argument of ep.tl.umap are meant to allow this. Unfortunately, using them is currently not possible without some workarounds.

Here is code to reproduce the errors. Uncommenting all 3 commented lines gives code that runs without errors.

import anndata
import numpy as np
import ehrapy as ep

X = np.random.random((5, 3))  # 5 observations (patients), 3 features
layers = {
    'layer_1': np.random.random((5, 3)),
    'layer_2': np.random.random((5, 3)),
}

adata_test = anndata.AnnData(X=X, obsm=layers)

ep.pp.neighbors(adata_test, use_rep='layer_1', key_added='layer_1')
# adata_test.uns["neighbors"] = None  # Uncomment to fix
ep.tl.umap(adata_test, neighbors_key='layer_1')

# del adata_test.uns["neighbors"]  # Uncomment to fix

ep.pp.neighbors(adata_test, use_rep='layer_2', key_added='layer_2')
# adata_test.uns["neighbors"] = None  # Uncomment to fix
ep.tl.umap(adata_test, neighbors_key='layer_2')

When running like this, the code will produce the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[47], line 17
     14 ep.pp.neighbors(adata_test, use_rep='layer_1', key_added='layer_1')
     16 # adata_test.uns["neighbors"] = None  # Uncomment to fix
---> 17 ep.tl.umap(adata_test, neighbors_key='layer_1')
     19 # del adata_test.uns["neighbors"]  # Uncomment to fix
     20 ep.pp.neighbors(adata_test, use_rep='layer_2', key_added='layer_2')

File /opt/conda/lib/python3.10/site-packages/ehrapy/tools/_scanpy_tl_api.py:162, in umap(adata, min_dist, spread, n_components, maxiter, alpha, gamma, negative_sample_rate, init_pos, random_state, a, b, copy, method, neighbors_key)
     81 def umap(
     82     adata: AnnData,
     83     min_dist: float = 0.5,
   (...)
     96     neighbors_key: Optional[str] = None,
     97 ) -> Optional[AnnData]:  # pragma: no cover
     98     """Embed the neighborhood graph using UMAP [McInnes18]_.
     99 
    100     UMAP (Uniform Manifold Approximation and Projection) is a manifold learning
   (...)
    160         **X_umap** : `adata.obsm` field UMAP coordinates of data.
    161     """
--> 162     if adata.uns["neighbors"] is None or neighbors_key not in adata.uns:
    163         return sc.tl.umap(
    164             adata=adata,
    165             min_dist=min_dist,
   (...)
    178             neighbors_key=neighbors_key,
    179         )
    180     else:

KeyError: 'neighbors'
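
The KeyError comes from the guard on line 162 of the traceback, which indexes adata.uns["neighbors"] directly even though only the custom key was written. The sketch below reproduces that behaviour with a plain dict standing in for adata.uns. The dict contents and the function names are illustrative assumptions, and guard_with_get is only one possible way to tolerate the missing key, not necessarily the fix adopted in ehrapy.

```python
# Plain dict standing in for `adata.uns` after
# ep.pp.neighbors(..., key_added='layer_1'): only the custom key exists.
# The stored value is a placeholder chosen for illustration.
uns = {"layer_1": {"connectivities_key": "layer_1_connectivities"}}
neighbors_key = "layer_1"


def guard_as_in_traceback(uns, neighbors_key):
    # Same expression as line 162 of the traceback: indexing raises KeyError
    # when the default "neighbors" slot was never written.
    return uns["neighbors"] is None or neighbors_key not in uns


def guard_with_get(uns, neighbors_key):
    # Hypothetical alternative: dict.get returns None for a missing key
    # instead of raising.
    return uns.get("neighbors") is None or neighbors_key not in uns


try:
    guard_as_in_traceback(uns, neighbors_key)
    raised = False
except KeyError:
    raised = True

print("original guard raised KeyError:", raised)
print("guard with .get evaluates to:", guard_with_get(uns, neighbors_key))
```

This also suggests why setting adata.uns["neighbors"] = None gets past this line: the key then exists, and the first half of the condition evaluates to True without raising.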

When the first comment is uncommented:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[49], line 20
     16 ep.tl.umap(adata_test, neighbors_key='layer_1')
     18 # del adata_test.uns["neighbors"]  # Uncomment to fix
---> 20 ep.pp.neighbors(adata_test, use_rep='layer_2', key_added='layer_2')
     21 # adata_test.uns["neighbors"] = None  # Uncomment to fix
     22 ep.tl.umap(adata_test, neighbors_key='layer_2')

File /opt/conda/lib/python3.10/site-packages/ehrapy/preprocessing/_scanpy_pp_api.py:245, in neighbors(adata, n_neighbors, n_pcs, use_rep, knn, random_state, method, metric, metric_kwds, key_added, copy)
    191 def neighbors(
    192     adata: AnnData,
    193     n_neighbors: int = 15,
   (...)
    202     copy: bool = False,
    203 ) -> Optional[AnnData]:  # pragma: no cover
    204     """Compute a neighborhood graph of observations [McInnes18]_.
    205 
    206     The neighbor search efficiency of this heavily relies on UMAP [McInnes18]_,
   (...)
    243          Instead of decaying weights, this stores distances for each pair of neighbors.
    244     """
--> 245     return sc.pp.neighbors(
    246         adata=adata,
    247         n_neighbors=n_neighbors,
    248         n_pcs=n_pcs,
    249         use_rep=use_rep,
    250         knn=knn,
    251         random_state=random_state,
    252         method=method,
    253         metric=metric,
    254         metric_kwds=metric_kwds,
    255         key_added=key_added,
    256         copy=copy,
    257     )

File /opt/conda/lib/python3.10/site-packages/scanpy/neighbors/__init__.py:178, in neighbors(adata, n_neighbors, n_pcs, use_rep, knn, method, transformer, metric, metric_kwds, random_state, key_added, copy)
    176 if adata.is_view:  # we shouldn't need this here...
    177     adata._init_as_actual(adata.copy())
--> 178 neighbors = Neighbors(adata)
    179 neighbors.compute_neighbors(
    180     n_neighbors,
    181     n_pcs=n_pcs,
   (...)
    188     random_state=random_state,
    189 )
    191 if key_added is None:

File /opt/conda/lib/python3.10/site-packages/legacy_api_wrap/__init__.py:80, in legacy_api.<locals>.wrapper.<locals>.fn_compatible(*args_all, **kw)
     77 @wraps(fn)
     78 def fn_compatible(*args_all: P.args, **kw: P.kwargs) -> R:
     79     if len(args_all) <= n_positional:
---> 80         return fn(*args_all, **kw)
     82     args_pos: P.args
     83     args_pos, args_rest = args_all[:n_positional], args_all[n_positional:]

File /opt/conda/lib/python3.10/site-packages/scanpy/neighbors/__init__.py:377, in Neighbors.__init__(self, adata, n_dcs, neighbors_key)
    375     neighbors_key = "neighbors"
    376 if neighbors_key in adata.uns:
--> 377     neighbors = NeighborsView(adata, neighbors_key)
    378     if "distances" in neighbors:
    379         self.knn = issparse(neighbors["distances"])

File /opt/conda/lib/python3.10/site-packages/scanpy/_utils/__init__.py:1036, in NeighborsView.__init__(self, adata, key)
   1033     self._distances = adata.obsp[self._dists_key]
   1035 # fallback to uns
-> 1036 self._connectivities, self._distances = _fallback_to_uns(
   1037     self._neighbors_dict,
   1038     self._connectivities,
   1039     self._distances,
   1040     self._conns_key,
   1041     self._dists_key,
   1042 )

File /opt/conda/lib/python3.10/site-packages/scanpy/_utils/__init__.py:974, in _fallback_to_uns(dct, conns, dists, conns_key, dists_key)
    973 def _fallback_to_uns(dct, conns, dists, conns_key, dists_key):
--> 974     if conns is None and conns_key in dct:
    975         conns = dct[conns_key]
    976     if dists is None and dists_key in dct:

TypeError: argument of type 'NoneType' is not iterable
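
The TypeError follows from the workaround itself: the second ep.pp.neighbors call reaches scanpy's Neighbors.__init__, which appears to fall back to the key "neighbors" (line 375 of the traceback), finds that key in adata.uns, and then runs a membership test against its value, which is None. A minimal sketch with a plain dict in place of adata.uns is shown below. Only the first three lines of the function body appear in the traceback; the remaining lines and the key names passed in are assumptions made for illustration.

```python
def fallback_to_uns(dct, conns, dists, conns_key, dists_key):
    # The first three statements mirror lines 974-976 of the traceback.
    if conns is None and conns_key in dct:
        conns = dct[conns_key]
    if dists is None and dists_key in dct:
        # Assumed continuation; not shown in the traceback.
        dists = dct[dists_key]
    return conns, dists


# State after the first workaround line: the default slot exists but holds None.
uns = {"neighbors": None, "layer_1": {}}
neighbors_key = "neighbors"

error = None
if neighbors_key in uns:  # True: the key is present even though its value is None
    try:
        # `conns_key in None` raises TypeError, as in the report.
        fallback_to_uns(uns[neighbors_key], None, None, "connectivities", "distances")
    except TypeError as exc:
        error = str(exc)

print(error)
```

Deleting the "neighbors" entry before the second ep.pp.neighbors call removes the None value, which is consistent with the second "Uncomment to fix" line in the reproduction code.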

When the first 2 comments are uncommented, the error is identical to the first case.

Version information

-----
anndata             0.10.8
ehrapy              0.9.0
matplotlib          3.7.4
numpy               1.24.3
optuna              3.5.0
pandas              2.2.2
scanpy              1.10.2
scipy               1.11.4
seaborn             0.13.2
session_info        1.0.0
sklearn             1.2.2
umap                0.5.5
-----
Crypto                                      3.19.0
OpenSSL                                     23.2.0
PIL                                         9.5.0
absl                                        NA
anyio                                       NA
arrow                                       1.2.3
asttokens                                   NA
astunparse                                  1.6.3
attr                                        23.1.0
attrs                                       23.1.0
autograd                                    NA
autograd_gamma                              NA
babel                                       2.12.1
backcall                                    0.2.0
boto3                                       1.26.100
botocore                                    1.33.1
brotli                                      NA
cachetools                                  4.2.4
causallearn                                 NA
certifi                                     2023.11.17
cffi                                        1.15.1
charset_normalizer                          3.2.0
cloud_tpu_client                            0.10
cloudpickle                                 2.2.1
colorama                                    0.4.6
colorlog                                    NA
comm                                        0.1.4
cryptography                                41.0.3
cycler                                      0.12.1
cython_runtime                              NA
cytoolz                                     0.12.2
dab0eaeee8bfae79490a0d4f23f5ad820bb199d8    NA
dask                                        2023.12.0
dateutil                                    2.8.2
debugpy                                     1.6.7.post1
decorator                                   5.1.1
defusedxml                                  0.7.1
deprecated                                  1.2.14
dill                                        0.3.7
dot_parser                                  NA
dowhy                                       0.11.1
entrypoints                                 0.4
etils                                       1.4.1
exceptiongroup                              1.1.3
executing                                   1.2.0
fastjsonschema                              NA
fhiry                                       4.0.0
filelock                                    3.12.2
flatbuffers                                 23.5.26
formulaic                                   1.0.2
fqdn                                        NA
fsspec                                      2023.12.2
future                                      0.18.3
gast                                        NA
geopandas                                   0.14.1
google                                      NA
google_auth_httplib2                        NA
googleapiclient                             NA
graphlib                                    NA
greenlet                                    2.0.2
grpc                                        1.60.0
grpc_status                                 NA
h5py                                        3.9.0
httplib2                                    0.21.0
idna                                        3.4
igraph                                      0.11.3
imblearn                                    0.11.0
importlib_metadata                          NA
importlib_resources                         NA
interface_meta                              1.3.0
ipykernel                                   6.25.1
ipython_genutils                            0.2.0
ipywidgets                                  7.7.1
isoduration                                 NA
jax                                         0.4.21
jaxlib                                      0.4.21
jedi                                        0.19.0
jinja2                                      3.1.4
jmespath                                    1.0.1
joblib                                      1.3.2
json5                                       NA
jsonpointer                                 2.0
jsonschema                                  4.19.0
jsonschema_specifications                   NA
jupyter_events                              0.9.0
jupyter_server                              2.12.1
jupyterlab_server                           2.24.0
kaggle_gcp                                  NA
kaggle_secrets                              NA
kaggle_web_client                           NA
keras                                       2.13.1
kiwisolver                                  1.4.5
lamin_utils                                 0.13.2
legacy_api_wrap                             NA
leidenalg                                   0.10.2
lifelines                                   0.29.0
llvmlite                                    0.40.1
log                                         NA
lz4                                         4.3.2
markupsafe                                  2.1.3
matplotlib_inline                           0.1.6
missingno                                   0.5.2
ml_dtypes                                   0.3.1
mmh3                                        NA
mpl_toolkits                                NA
mpmath                                      1.3.0
natsort                                     8.4.0
nbformat                                    5.9.2
networkx                                    3.1
numba                                       0.57.1
numexpr                                     2.8.8
oauth2client                                4.1.3
opentelemetry                               NA
opt_einsum                                  v3.3.0
overrides                                   NA
packaging                                   21.3
parso                                       0.8.3
patsy                                       0.5.6
pexpect                                     4.8.0
pickleshare                                 0.7.5
pkg_resources                               NA
platformdirs                                4.1.0
plotly                                      5.16.1
prodict                                     NA
prometheus_client                           NA
prompt_toolkit                              3.0.39
proto                                       NA
psutil                                      5.9.3
ptyprocess                                  0.7.0
pure_eval                                   0.2.2
pyarrow                                     14.0.1
pyasn1                                      0.5.0
pyasn1_modules                              0.3.0
pycparser                                   2.21
pydev_ipython                               NA
pydevconsole                                NA
pydevd                                      2.9.5
pydevd_file_utils                           NA
pydevd_plugins                              NA
pydevd_tracing                              NA
pydot                                       1.4.2
pygments                                    2.16.1
pynndescent                                 0.5.11
pyparsing                                   3.1.1
pyproj                                      3.6.1
pythonjsonlogger                            NA
pytz                                        2023.3
rapidfuzz                                   3.5.2
referencing                                 NA
requests                                    2.31.0
rfc3339_validator                           0.1.4
rfc3986_validator                           0.1.1
rich                                        NA
rpds                                        NA
rsa                                         4.9
rtree                                       1.1.0
send2trash                                  NA
setuptools                                  68.1.2
setuptools_scm                              NA
shapely                                     1.8.5.post1
simplejson                                  3.19.2
sitecustomize                               NA
six                                         1.16.0
sniffio                                     1.3.0
socks                                       1.7.1
sqlalchemy                                  2.0.20
stack_data                                  0.6.2
statsmodels                                 0.14.2
sympy                                       1.12
tableone                                    0.9.1
tabulate                                    0.9.0
tblib                                       3.0.0
tensorboard                                 2.13.0
tensorflow                                  2.13.0
tensorflow_probability                      0.21.0
termcolor                                   NA
texttable                                   1.7.0
thefuzz                                     0.22.1
threadpoolctl                               3.2.0
timeago                                     1.0.14
tlz                                         0.12.2
tomli                                       2.0.1
toolz                                       0.12.0
torch                                       2.0.0+cpu
tornado                                     6.3.3
tqdm                                        4.66.1
traitlets                                   5.9.0
tree                                        0.1.8
typing_extensions                           NA
uri_template                                NA
uritemplate                                 3.0.1
urllib3                                     1.26.15
wcwidth                                     0.2.6
webcolors                                   1.13
websocket                                   1.6.2
wrapt                                       1.15.0
xxhash                                      NA
yaml                                        6.0.1
zipp                                        NA
zmq                                         24.0.1
zoneinfo                                    NA
zstandard                                   0.22.0
-----
IPython             8.14.0
jupyter_client      7.4.9
jupyter_core        5.3.1
jupyterlab          4.0.9
notebook            6.5.5
-----
Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
Linux-5.15.154+-x86_64-with-glibc2.31
-----
Session information updated at 2024-08-22 11:44

eroell commented 2 months ago

Wow, what a great description of the issue, thanks so much for reporting, @VladimirShitov! You got this bug really spot on! A PR fixing this is on the way!

VladimirShitov commented 2 months ago

Haha, I just had to find a workaround to work with my data, and ChatGPT nicely generated the toy test. I was happy to contribute!