theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License

Allow the incorporation of genes in adata.raw.X for detection of driver genes #870

Closed (JesseRop closed this issue 2 years ago)

JesseRop commented 2 years ago

Kindly advise on how to include all high-quality genes, i.e. those in adata.raw.X, when estimating driver genes, or consider adding a feature that supports this.

In the workflow below, I'm not able to include all genes in the estimation of driver genes when I use adata.raw.to_adata() as the expression matrix.

I have tried the following workflow, but it errors out at ck_raw = ConnectivityKernel(adata.raw.to_adata()).compute_transition_matrix(), probably because adata.raw.to_adata() drops the .obsp slot.

Kindly advise on the best way to include all genes when detecting drivers. Thanks.

adata = cr.datasets.pancreas()

scv.pp.normalize_per_cell(adata)

sc.pp.log1p(adata)

scv.pp.filter_genes_dispersion(adata, subset=False)

sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)

adata.raw = adata

adata

adata.raw.to_adata()

scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000, subset_highly_variable=False)
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=30)
scv.pp.moments(adata, n_pcs=None, n_neighbors=None)

scv.tl.velocity(adata, mode="stochastic")
scv.tl.velocity_graph(adata)

from cellrank.tl.kernels import VelocityKernel

vk = VelocityKernel(adata)

from cellrank.tl.kernels import ConnectivityKernel

ck_raw = ConnectivityKernel(adata.raw.to_adata()).compute_transition_matrix()

combined_kernel = 0.8 * vk + 0.2 * ck
print(combined_kernel)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/ul/_utils.py:138, in _get_neighs(adata, mode, key)
    137 try:
--> 138     res = _read_graph_data(adata, key)
    139     assert isinstance(res, (np.ndarray, spmatrix))

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/ul/_utils.py:177, in _read_graph_data(adata, key)
    175     return adata.obsp[key]
--> 177 raise KeyError(f"Unable to find data in `adata.obsp[{key!r}]`.")

KeyError: "Unable to find data in `adata.obsp['connectivities']`."

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
Input In [737], in <cell line: 15>()
     11 vk = VelocityKernel(adata)
     13 from cellrank.tl.kernels import ConnectivityKernel
---> 15 ck_raw = ConnectivityKernel(adata.raw.to_adata()).compute_transition_matrix()
     18 combined_kernel = 0.8 * vk + 0.2 * ck
     19 print(combined_kernel)

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/tl/kernels/_connectivity_kernel.py:41, in ConnectivityKernel.__init__(self, adata, conn_key, check_connectivity)
     35 def __init__(
     36     self,
     37     adata: AnnData,
     38     conn_key: str = "connectivities",
     39     check_connectivity: bool = False,
     40 ):
---> 41     super().__init__(
     42         adata,
     43         conn_key=conn_key,
     44         check_connectivity=check_connectivity,
     45     )

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/tl/kernels/_base_kernel.py:490, in Kernel.__init__(self, adata, **kwargs)
    488 self._adata = adata
    489 self._n_obs = adata.n_obs
--> 490 self._read_from_adata(**kwargs)

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/tl/kernels/_mixins.py:27, in ConnectivityMixin._read_from_adata(self, conn_key, check_connectivity, **kwargs)
     25 # fmt: off
     26 self._conn_key = conn_key
---> 27 self._conn = _get_neighs(self.adata, mode="connectivities", key=conn_key)
     28 self._conn = csr_matrix(self._conn).astype(np.float64, copy=False)
     29 # fmt: on

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/ul/_utils.py:141, in _get_neighs(adata, mode, key)
    139         assert isinstance(res, (np.ndarray, spmatrix))
    140     except (KeyError, AssertionError):
--> 141         res = _read_graph_data(adata, f"{_modify_neigh_key(key)}_{mode}")
    143 if not isinstance(res, (np.ndarray, spmatrix)):
    144     raise TypeError(
    145         f"Expected to find `numpy.ndarray` or `scipy.sparse.spmatrix`, found `{type(res)}`."
    146     )

File ~/envs/scv_jupcon/lib/python3.9/site-packages/cellrank/ul/_utils.py:177, in _read_graph_data(adata, key)
    174 if key in adata.obsp:
    175     return adata.obsp[key]
--> 177 raise KeyError(f"Unable to find data in `adata.obsp[{key!r}]`.")

KeyError: "Unable to find data in `adata.obsp['neighbors_connectivities']`."
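For readers following the traceback: the two chained KeyErrors come from a two-step lookup, first under the requested key and then under a fallback name. The sketch below mimics that behaviour with a plain dict standing in for `adata.obsp`. The function `find_graph` is hypothetical, and the `neighbors_` fallback prefix is an assumption inferred only from the error messages above, not taken from the CellRank source.

```python
def find_graph(obsp, key="connectivities", mode="connectivities"):
    """Hypothetical stand-in for the graph lookup seen in the traceback."""
    # First attempt: the key exactly as requested.
    if key in obsp:
        return obsp[key]
    # Second attempt (assumed from the error text): a "neighbors_"-prefixed key.
    fallback = f"neighbors_{mode}"
    if fallback in obsp:
        return obsp[fallback]
    raise KeyError(f"Unable to find data in `adata.obsp[{fallback!r}]`.")


# Dict stand-ins: the processed object carries a graph, the raw copy does not.
processed_obsp = {"connectivities": "graph from sc.pp.neighbors"}
raw_copy_obsp = {}  # what the error implies for adata.raw.to_adata()

print(find_graph(processed_obsp))
try:
    find_graph(raw_copy_obsp)
except KeyError as err:
    print("lookup failed:", err)
```

Under these assumptions, the kernel only works on an object that still carries the neighbor graph, which is why it needs the processed `adata` rather than the copy built from `.raw`.
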
michalk8 commented 2 years ago

Hi @JesseRop,

you can use g.compute_lineage_drivers(..., use_raw=True), which takes the expression from adata.raw for the driver computation. You shouldn't pass the .raw object to the kernel, because adata.raw usually contains filtered/normalized data that has not been scaled or otherwise further processed; please see the example snippet below. Also, scv.pp.filter_genes_dispersion expects data that has not been log-transformed. I usually prefer scv.pp.filter_and_normalize for quick filtering/normalization.

import scanpy as sc
import scvelo as scv
import cellrank as cr
from cellrank.tl.kernels import VelocityKernel, ConnectivityKernel
from cellrank.tl.estimators import GPCCA

adata = cr.datasets.pancreas()

scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000, subset_highly_variable=False)
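# keep all filtered/normalized genes in `.raw` for the driver computation
adata.raw = adata.copy()
# restrict the working object to highly variable genes (the result must be assigned)
adata = adata[:, adata.var['highly_variable']].copy()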

sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=30)
scv.pp.moments(adata, n_pcs=None, n_neighbors=None)

scv.tl.velocity(adata, mode="stochastic")
scv.tl.velocity_graph(adata)

vk = VelocityKernel(adata).compute_transition_matrix()
ck = ConnectivityKernel(adata).compute_transition_matrix()
combined_kernel = 0.8 * vk + 0.2 * ck

g = GPCCA(combined_kernel)
g.compute_schur()
g.compute_macrostates(n_states=3, cluster_key="clusters")
g.set_terminal_states_from_macrostates()
g.compute_absorption_probabilities()

# important: `use_raw=True`
df = g.compute_lineage_drivers(lineages="Alpha", use_raw=True)
# `df` is also present in `adata.raw.varm['terminal_lineage_drivers']`

g.plot_lineage_drivers("Alpha", use_raw=True)
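
As a rough illustration of why `use_raw=True` is enough here: to my understanding, lineage drivers are obtained by correlating each gene's expression with the fate probabilities of a lineage, so every gene is scored on its own. The toy sketch below uses plain Python and made-up numbers (the gene names, expression values, and `fate_prob` are all hypothetical) to show that scoring all genes leaves the scores of the highly variable genes unchanged and only adds the remaining genes.

```python
from math import sqrt


def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    norm_x = sqrt(sum((a - mean_x) ** 2 for a in x))
    norm_y = sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (norm_x * norm_y)


# Made-up expression values per gene across five cells.
raw_expression = {
    "gene_a": [0.1, 0.4, 0.9, 1.5, 2.0],  # treated as highly variable
    "gene_b": [2.0, 1.6, 1.1, 0.5, 0.2],  # treated as highly variable
    "gene_c": [1.0, 1.1, 0.9, 1.2, 1.0],  # only present in the "raw" set
}
highly_variable = {"gene_a", "gene_b"}

# Made-up fate probabilities of the same five cells toward one lineage.
fate_prob = [0.05, 0.2, 0.5, 0.8, 0.95]


def score(expression):
    """Correlate every gene with the fate probabilities, one gene at a time."""
    return {gene: pearson(values, fate_prob) for gene, values in expression.items()}


hvg_scores = score({g: v for g, v in raw_expression.items() if g in highly_variable})
all_scores = score(raw_expression)

# Genes ranked from most positively to most negatively correlated.
ranked = sorted(all_scores, key=all_scores.get, reverse=True)
print(ranked)
```

The transition matrix and fate probabilities still come from the processed object; only the expression matrix used for the correlation changes.
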
Marius1311 commented 2 years ago

Thanks @michalk8, I assume this is solved and I'm closing the issue.

JesseRop commented 2 years ago

Yes! This is solved. Thanks @Marius1311 and @michalk8.