theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License

Embedding projection only works sometimes #603

Closed Marius1311 closed 3 years ago

Marius1311 commented 3 years ago

The embedding projection of transition matrices via kernel.compute_projection produces unexpected results for some kernels. Our idea for this method was to have something similar to scVelo's visualization of velocities in a 2D embedding, i.e. a UMAP or t-SNE. So we re-implemented this, and it works fine for the VelocityKernel: it gives essentially the same picture we get if we use scVelo for the plotting.

To produce a projection of a transition matrix via the kernel interface, you just call

kernel.compute_projection() 
scv.pl.velocity_embedding_stream(adata, vkey='T_fwd') # for streamlines, can use the other scVelo functions here as well

and this should work for any kernel.

The trouble starts when we use kernels other than the VelocityKernel; see below for, e.g., the external WOTKernel. In Fig. 1, I'm showing the trajectories, which look fine: they go from the early days (day 0) to the late days (day 18), where they terminate. However, when I compute the embedding projection, I get the arrows shown in Fig. 2.

The same thing happens for the PseudotimeKernel, but only for the hard threshold scheme; see Figs. 3 and 4 below for the hard and soft thresholding schemes, respectively. For the soft threshold scheme, the projected arrows (streamlines) look fine and reflect the expected direction. For the hard threshold scheme, the projection is totally wrong. I would like to emphasize that I don't think the actual transition matrix is wrong: when I use the transition matrix to plot random walks starting in the middle (Fev+ cluster) of the UMAP, they go in the right direction and terminate in the right clusters (e.g. Beta). See Fig. 5 for the simulated random walks.

So something is going wrong here when we use the hard threshold scheme. Note that in the hard threshold scheme, we actually remove some graph edges, whereas in the soft scheme, we just re-weight them and always keep all graph edges. This means that the sparsity pattern of the adjacency matrix A is symmetric in the soft thresholding scheme, but not in the hard thresholding scheme. Similarly, the sparsity pattern is not symmetric in the WOTKernel, but it is symmetric in the VelocityKernel. This could be part of the problem; a quick way to check this is sketched below.
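As a quick sanity check (a minimal sketch with a hypothetical helper, not part of CellRank), the symmetry of a sparsity pattern can be tested by comparing the boolean pattern of a matrix with that of its transpose:

from scipy import sparse

def has_symmetric_pattern(A: sparse.spmatrix) -> bool:
    # compare the boolean (0/1) sparsity pattern with its transpose;
    # nnz == 0 means the two patterns agree everywhere
    pattern = A.astype(bool)
    return (pattern != pattern.T).nnz == 0

A = sparse.random(100, 100, density=0.05, format="csr")
print(has_symmetric_pattern(A))        # almost surely False
print(has_symmetric_pattern(A + A.T))  # True: symmetrizing restores symmetry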

Projection of the transition matrix into low dimensions isn't trivial; this has already been discussed in the original RNA velocity publication (La Manno et al., Nature 2018) in Supplementary Note 2, Section 11, "Uncertainty and limitations of neighborhood-based velocity projections". This would be a good starting point for looking into this. I think we somehow have to be careful with the local cell density when projecting into low dimensions, but I don't fully understand it yet.

Fig. 1

Trajectories on reprogramming data from Schiebinger et al., Cell 2019. [image]

Fig. 2

Projected velocity streams for reprogramming data. [image] Now these look very weird! I would expect the streamlines to point from earlier to later days.

Fig. 3

Pancreas data, soft threshold scheme. [image]

Fig. 4

Pancreas data, hard threshold scheme. [image]

Fig. 5

Pancreas data, simulated random walks. [image] Starting points are denoted by black dots, endpoints by yellow dots. Showing 5 simulations. Random walks are visualized as lines, colored according to simulation time from black (early) via purple to yellow (late).

Versions:

...

Marius1311 commented 3 years ago

Any updates on this yet @WeilerP ?

WeilerP commented 3 years ago

Any updates on this yet @WeilerP ?

No sorry, didn't have the time to take a look at it yet. Will do and get back to you ASAP (it's on my TODO list).

Marius1311 commented 3 years ago

@WeilerP, I think we got it: we need to make sure we don't mess with the estimate of the local density, i.e. the 1/n_i term. We need to make sure that, for every kernel, this term corresponds to the actual KNN graph density, not to our final transition matrix, which may have had edges removed, etc.
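To illustrate the bias with a toy example (a hedged sketch with made-up numbers, not CellRank code): if the uniform correction averages only over the edges that survive in the transition matrix rather than over the full kNN neighborhood, the directional signal can cancel out entirely:

import numpy as np

# cell i has 4 kNN neighbors with unit displacement vectors in the embedding;
# the hard scheme pruned the edges pointing "backwards" (probability 0)
dX = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
probs = np.array([0.5, 0.5, 0.0, 0.0])

# correction over the full kNN neighborhood keeps the directional signal
v_knn = probs.dot(dX) - dX.mean(axis=0)  # -> [0.5, 0.5]

# correction over the transition matrix's support subtracts the mean of only
# the surviving directions, cancelling the signal completely
surv = probs > 0
v_tmat = probs[surv].dot(dX[surv]) - dX[surv].mean(axis=0)  # -> [0.0, 0.0]
print(v_knn, v_tmat)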

I'm excited to hear whether this solves the problem with the hard-threshold PseudotimeKernel - that would tell us we're on the right track and we can use the same solution for the WOTKernel.

michalk8 commented 3 years ago

@WeilerP @Marius1311 Can confirm that it was the case @WeilerP found; here's a modified code snippet that works (I haven't debugged it yet for corner cases, maybe there's a cleaner implementation, and we maybe need some check for the connectivities [that all kernels use the same ones?]):

        start = logg.info(f"Projecting transition matrix onto `{basis}`")
        emb = _get_basis(self.adata, basis)
        T = self.transition_matrix
        T.sort_indices()  # ensure `row.data` is aligned with sorted `row.indices`
        T_emb = np.empty_like(emb)
        conn = self.kernels[0]._conn  # the kNN connectivities define the neighborhood
        conn.sort_indices()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i, row in enumerate(T):
                # transition probabilities over ALL kNN neighbors of cell i;
                # edges removed by the kernel keep probability 0
                cixs = conn[i].indices
                probs = np.zeros_like(cixs, dtype=T.dtype)
                ixs = np.where(np.isin(cixs, row.indices))[0]
                probs[ixs] = row.data

                # unit displacement vectors to all kNN neighbors in the embedding
                dX = emb[cixs] - emb[i, None]
                if np.any(np.isnan(dX)):
                    T_emb[i] = np.nan
                else:
                    dX /= np.linalg.norm(dX, axis=1)[:, None]
                    dX = np.nan_to_num(dX)
                    # expected displacement minus the uniform correction
                    T_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)

Code to produce below figures:

import cellrank as cr
import scvelo as scv

adata = cr.datasets.pancreas_preprocessed()
k = cr.tl.kernels.PseudotimeKernel(adata).compute_transition_matrix("hard", nu=0.5, b=20, frac_to_keep=0)
k.compute_projection()
scv.pl.velocity_embedding_stream(adata, vkey='T_fwd', basis='umap', dpi=200)#, smooth=1)

Soft: [image] Hard: [image]

WeilerP commented 3 years ago

@michalk8, thanks! Just to summarize: following La Manno et al., the predicted velocity displacement of a cell is given by

$$\vec{v}_i = \sum_{j \neq i} \left( p_{ij} - \frac{1}{n} \right) \frac{\vec{x}_j - \vec{x}_i}{\lVert \vec{x}_j - \vec{x}_i \rVert},$$

where $p_{ij}$ are the transition probabilities, $\vec{x}_i$ is the embedding position of cell $i$, and $n$ is the number of its neighbors.

In the current implementation, here, both the displacement vectors and the uniform correction are computed over the non-zero entries of the transition matrix only, rather than over the full kNN neighborhood.

Using the PseudoTimeKernel with a variation of @michalk8's proposal, namely

        emb = _get_basis(self.adata, basis)
        T_emb = np.empty_like(emb)

        T = self.transition_matrix
        T.sort_indices()

        conn = self.kernels[0]._conn
        conn.sort_indices()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i, row in enumerate(T):
                # displacements to ALL kNN neighbors, not only to the targets
                # that survive in the transition matrix
                conn_idxs = conn[i, :].indices
                dX = emb[conn_idxs] - emb[i, None]

                # transition probabilities aligned to the kNN neighbors;
                # removed edges keep probability 0
                probs = np.zeros_like(conn_idxs, dtype=T.dtype)
                probs[np.isin(conn_idxs, row.indices)] = row.data

                if np.any(np.isnan(dX)):
                    T_emb[i] = np.nan
                else:
                    dX /= np.linalg.norm(dX, axis=1)[:, None]
                    dX = np.nan_to_num(dX)
                    # uniform correction now uses the full kNN neighborhood size
                    T_emb[i] = probs.dot(dX) - dX.sum(0) / dX.shape[0]

the embedding projection for the hard thresholding scheme with default parameters looks as follows:

[image]

There are two issues left, IMO:

  1. The non-trivial entries of PseudotimeKernel.transition_matrix are flipped, i.e.

ptk = PseudotimeKernel(adata)
ptk.compute_transition_matrix()
ptk.transition_matrix[0, :].indices

returns the indices in decreasing order (see also the CSR sketch after this list). However, after running

ptk = PseudotimeKernel(adata)
ptk.compute_transition_matrix()
ptk.compute_projection()

with the above-mentioned code, the order is correct. The entries should be in the correct order in the first place, no? Also, running ptk.compute_projection() should not tamper with the transition matrix in place, IMO. I think this should be fixed in a separate issue first.

  2. The number of neighbors considered for the hard thresholding scheme is not cell specific (see here), i.e. it is not the number of neighbors in the symmetrized connectivity matrix. This leads to different transition matrices for the PseudotimeKernel with the hard thresholding scheme and frac_to_keep=1 compared to the ConnectivityKernel. The following assertion fails:
import cellrank as cr
import scvelo as scv

adata = cr.datasets.pancreas_preprocessed()
ptk = cr.tl.kernels.PseudotimeKernel(adata).compute_transition_matrix("hard", frac_to_keep=1)
ck = cr.tl.kernels.ConnectivityKernel(adata).compute_transition_matrix()

assert ptk.transition_matrix.getnnz(axis=1).mean() == ck.transition_matrix.getnnz(axis=1).mean()

This should be fixed first, as well, IMO.
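Regarding the first point, a minimal standalone sketch (plain scipy, not CellRank code) of why the index order matters: CSR column indices are only guaranteed to be sorted after calling .sort_indices(), which reorders .indices and .data in place:

import numpy as np
from scipy.sparse import csr_matrix

# build a single-row CSR matrix whose column indices are stored in
# decreasing order
data = np.array([0.2, 0.3, 0.5])
indices = np.array([5, 3, 1])
indptr = np.array([0, 3])
mat = csr_matrix((data, indices, indptr), shape=(1, 6))

print(mat[0, :].indices)  # [5 3 1] -- decreasing order
mat.sort_indices()        # sorts indices (and reorders data) in place
print(mat[0, :].indices)  # [1 3 5]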

michalk8 commented 3 years ago

Also, running ptk.compute_projection() should not tamper with the transition matrix in place, IMO. I think this should be fixed in a separate issue first.

Agree it should not modify the tmat; it was just to quickly test. I'd say we should ensure that the connectivities and the biased connectivities are always sorted, i.e. sort them in _read_from_adata and bias_knn, respectively.

PseudoTimeKernel with hard thresholding scheme and frac_to_keep=1 compared to ConnectivityKernel

I think @Marius1311 can tell you more about why it's done this way; we've had a similar discussion in https://github.com/theislab/cellrank/issues/530. Also, the threshold is clipped here to stay close to Palantir.

Marius1311 commented 3 years ago

Let's discuss this today in a call, see Mattermost. The main points, in my opinion, are:

Feel free to add/modify. BTW, great work both of you, thanks for really diving into this!

Marius1311 commented 3 years ago

Outcomes

WeilerP commented 3 years ago

I checked the PseudotimeKernel with a custom scheme which uses the symmetrized kNN matrix:

import numpy as np

def callback(
    cell_pseudotime: float,
    neigh_pseudotime: np.ndarray,
    neigh_conn: np.ndarray,
    fraction: float,
):
    # rank neighbors by connectivity, strongest first
    ixs = np.flip(np.argsort(neigh_conn))
    n_neighs = len(ixs)
    k_thresh = max(0, int(np.floor(n_neighs * fraction)))
    close_ixs, far_ixs = ixs[:k_thresh], ixs[k_thresh:]

    # keep "far" neighbors only if they lie ahead in pseudotime
    mask_keep = cell_pseudotime <= neigh_pseudotime[far_ixs]
    far_ixs_keep = far_ixs[mask_keep]

    biased_conn = np.zeros_like(neigh_conn)
    biased_conn[close_ixs] = neigh_conn[close_ixs]
    biased_conn[far_ixs_keep] = neigh_conn[far_ixs_keep]

    return biased_conn
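For completeness, a hedged usage sketch (assuming the callable is passed as the threshold_scheme argument of compute_transition_matrix, with fraction forwarded as a keyword argument; the exact parameter forwarding may differ):

import cellrank as cr

adata = cr.datasets.pancreas_preprocessed()
ptk = cr.tl.kernels.PseudotimeKernel(adata).compute_transition_matrix(
    threshold_scheme=callback, fraction=0.3
)
ptk.compute_projection()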

Using frac_to_keep=0.3 and fraction=0.3, the results are almost identical:

Current implementation

[image]

Custom scheme using symmetrized kNN graph

[image]

Comparing the custom scheme using fraction=1 with the ConnectivityKernel gives slightly different results:

Custom scheme using symmetrized kNN graph w/ fraction=1

[image]

Connectivity kernel

[image]

In a nutshell

Also, @michalk8, why do we consider at most 30 neighbors for the hard thresholding scheme?

https://github.com/theislab/cellrank/blob/395bdccaa3509594818d5ce99d030cb9394d4b52/cellrank/tl/kernels/_pseudotime_schemes.py#L167

Is this to mimic the original Palantir implementation?

Marius1311 commented 3 years ago

Thanks @WeilerP!

With your second point, I imagine you refer to the difference between your custom implementation with fraction=1 and the ConnectivityKernel. Could you check whether density normalization was set to True for the ConnectivityKernel? That could explain the differences.

WeilerP commented 3 years ago

With your second point, I imagine you refer to the difference between your custom implementation with fraction=1 and the ConnectivityKernel. Could you check whether density normalization was set to True for the ConnectivityKernel? That could explain the differences.

Yes, the difference was w.r.t. the custom implementation with fraction=1 and the ConnectivityKernel. I reran the ConnectivityKernel using

ck = ConnectivityKernel(adata).compute_transition_matrix(density_normalize=False)
ck.compute_projection()
scv.pl.velocity_embedding_stream(adata, vkey='T_fwd')

and the two now match:

PseudoTimeKernel w/ custom callback as defined here

[image]

ConnectivityKernel

[image]

WeilerP commented 3 years ago

I also played around with the WOTKernel.

Current implementation

Using the current implementation, the projection is wrong, as you already mentioned in the issue description.

[image]

Without random uniform correction

As we do not use a connectivity matrix in the context of WOT, we can (or should) change

https://github.com/theislab/cellrank/blob/0d8747db8777864161d543f6c253660557753f8f/cellrank/tl/kernels/_base_kernel.py#L311

to

T_emb[i] = probs.dot(dX)

which already improves the results.

[image]
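For context, a self-contained sketch of this simplified projection loop (hypothetical helper name; T is a row-stochastic csr_matrix, emb the 2D embedding coordinates):

import numpy as np
from scipy.sparse import csr_matrix

def project_without_correction(T: csr_matrix, emb: np.ndarray) -> np.ndarray:
    T_emb = np.empty_like(emb)
    for i in range(T.shape[0]):
        row = T[i]
        # displacement vectors from cell i to its transition targets
        dX = emb[row.indices] - emb[i, None]
        if np.any(np.isnan(dX)):
            T_emb[i] = np.nan
            continue
        dX /= np.linalg.norm(dX, axis=1)[:, None]  # unit vectors
        dX = np.nan_to_num(dX)
        # expected displacement only; no uniform correction term
        T_emb[i] = row.data.dot(dX)
    return T_emb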

Weighted by inverse connectivities

To incorporate some correction, we can weight the transition probabilities by the inverse (summed) connectivities of the two observations involved, i.e.

        from scvelo.tools.velocity_embedding import quiver_autoscale

        if self._transition_matrix is None:
            raise RuntimeError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )

        start = logg.info(f"Projecting transition matrix onto `{basis}`")
        emb = _get_basis(self.adata, basis)
        T_emb = np.empty_like(emb)

        # summed connectivities per observation, used as weights below
        q = np.asarray(self._conn.sum(axis=0)).squeeze()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i, row in enumerate(self.transition_matrix):
                dX = emb[row.indices] - emb[i, None]
                if np.any(np.isnan(dX)):
                    T_emb[i] = np.nan
                else:
                    dX /= np.linalg.norm(dX, axis=1)[:, None]
                    dX = np.nan_to_num(dX)
                    probs = row.data
                    connectivity_weights = q[i] * q[row.indices]
                    T_emb[i] = (probs / connectivity_weights).dot(dX)

[image]

This looks similar to the previous result, I would say.


This approach of weighting by the inverse sum of connectivities does not work well on the PseudoTimeKernel (hard threshold). The following results already use the correct number of samples, i.e. the implementation from here. Things looked worse with the current implementation.

Only weighting

This uses

T_emb[i, :] = (probs / connectivity_weights).dot(dX)
[image]

weighted + random uniform unweighted correction

T_emb[i, :] = (probs / connectivity_weights).dot(dX) - dX.sum(0) / dX.shape[0]
[image]

weighted + weighted uniform random correction

T_emb[i, :] = (probs / connectivity_weights).dot(dX) - (dX / connectivity_weights.reshape(-1, 1)).sum(0) / dX.shape[0]
[image]

or using

k = cr.tl.kernels.PseudotimeKernel(adata).compute_transition_matrix("hard", nu=0.5, b=20, frac_to_keep=0)
k.compute_projection()
scv.pl.velocity_embedding_stream(adata, vkey='T_fwd', basis='umap')

as @michalk8 suggests here:

[image]

These last two results look very similar to the cases without weighting.

Marius1311 commented 3 years ago

Yes, the difference was w.r.t. the custom implementation with fraction=1 and the ConnectivityKernel. I reran the ConnectivityKernel using

ck = ConnectivityKernel(adata).compute_transition_matrix(density_normalize=False)
ck.compute_projection()
scv.pl.velocity_embedding_stream(adata, vkey='T_fwd')

and the two now match:

Nice, we figured this one out!

Marius1311 commented 3 years ago

Okay, so I take from this that we haven't really figured it out for the WOT kernel yet. However, we know how to do it for all KNN-based transition matrices. To prevent getting lost, I would suggest we implement the fix for the embedding projection we found for the hard threshold kernel, discuss the hard thresholding scheme (number of neighbors considered) in a separate issue, and also discuss embedding projection for non-KNN-based kernels in a separate issue. What's your opinion on this @WeilerP ?

WeilerP commented 3 years ago

Okay, so I take from this that we haven't really figured it out for the WOT kernel yet. However, we know how to do it for all KNN-based transition matrices. To prevent getting lost, I would suggest we implement the fix for the embedding projection we found for the hard threshold kernel, discuss the hard thresholding scheme (number of neighbors considered) in a separate issue, and also discuss embedding projection for non-KNN-based kernels in a separate issue. What's your opinion on this @WeilerP ?

Sounds exactly like what I would do! I'll open a PR fixing the embedding projection ASAP.

michalk8 commented 3 years ago

closed via #665

Marius1311 commented 3 years ago

Amazing, thanks everyone.