KrishnaswamyLab / scprep

A collection of scripts and tools for loading, processing, and handling single cell data.
MIT License

Add support for Diffusion Pseudotime in `scprep.run` #75

Open dburkhardt opened 4 years ago

dburkhardt commented 4 years ago

To do:

dburkhardt commented 4 years ago

Code sketch for Diffusion Pseudotime class

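import graphtools as gt
import numpy as np
from scipy.spatial.distance import pdist, squareform


class DiffusionPseudotime():

    def fit(self, data):
        G = gt.Graph(data, n_pca=100, use_pygsp=True)
        T = G.diff_op.toarray()

        # Calculate eigenvectors of the diffusion operator
        W, V = np.linalg.eig(T)

        # Remove first eigenspace (eig does not sort, so pick the largest eigenvalue)
        v0 = V[:, np.argmax(W.real)].real
        T_tilde = T - np.outer(v0, v0)

        # Calculate M, the sum of all powers of T_tilde
        I = np.eye(T.shape[1])
        M = np.linalg.inv(I - T_tilde) - I

        # Calc DPT as pairwise distances between rows of M
        DPT = squareform(pdist(M))

        self.DPT = DPT
        return DPT

    def transform(self, root_cell):
        # Pseudotime is the DPT distance of every cell from the root cell
        self.pseudotime = self.DPT[root_cell]
        return self.pseudotime

    def calculate_inverse_pseudotime(self, end_cell):
        # Same lookup, measured from a terminal cell instead of the root
        self.inv_pseudotime = self.DPT[end_cell]
        return self.inv_pseudotime

    def fit_transform(self, data, root_cell):
        self.fit(data)
        return self.transform(root_cell)
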
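
The matrix algebra in the sketch can be sanity-checked on a toy example. This is only a minimal sketch under assumptions: it uses a symmetric-normalized transition matrix (so `np.linalg.eigh` applies) rather than the graphtools diffusion operator, and the 4-point affinity matrix `K` and the helper name `accumulated_transitions` are made up for illustration. It checks that `inv(I - T_tilde) - I` matches the sum of powers of `T_tilde`, then takes row distances as DPT.

```python
import numpy as np


def accumulated_transitions(T):
    """Return (M, T_tilde) for a symmetric transition matrix T.

    M = inv(I - T_tilde) - I, where T_tilde is T with its leading
    eigenspace removed.
    """
    # eigh returns eigenvalues in ascending order for symmetric input,
    # so the last column is the eigenvector of the largest eigenvalue.
    eigvals, eigvecs = np.linalg.eigh(T)
    psi0 = eigvecs[:, -1]
    T_tilde = T - np.outer(psi0, psi0)
    identity = np.eye(T.shape[0])
    M = np.linalg.inv(identity - T_tilde) - identity
    return M, T_tilde


# Hypothetical affinities between 4 points arranged roughly in a chain
K = np.array([
    [1.0, 0.5, 0.1, 0.0],
    [0.5, 1.0, 0.5, 0.1],
    [0.1, 0.5, 1.0, 0.5],
    [0.0, 0.1, 0.5, 1.0],
])
degrees = K.sum(axis=1)
T = K / np.sqrt(np.outer(degrees, degrees))  # D^-1/2 K D^-1/2

M, T_tilde = accumulated_transitions(T)

# Compare the closed form against a truncated power series
series = np.zeros_like(T)
power = np.eye(T.shape[0])
for _ in range(500):
    power = power @ T_tilde
    series += power

# DPT: Euclidean distance between rows of M
diff = M[:, None, :] - M[None, :, :]
dpt = np.sqrt((diff ** 2).sum(axis=-1))
pseudotime_from_root = dpt[0]  # root cell is index 0 in this toy example
```

With a row-stochastic (non-symmetric) operator such as `G.diff_op`, the left and right leading eigenvectors differ, so the outer-product step would likely need adjusting.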
dburkhardt commented 4 years ago

The more I think about this, the more I realize it might make more sense not to implement this ourselves.

Running DPT from scanpy is very fast and straightforward:

import anndata, scanpy
import numpy as np

data = np.random.normal(size=(3000,5000))

adata = anndata.AnnData(data)

adata.uns['iroot'] = 2977

scanpy.pp.neighbors(adata)
scanpy.tl.diffmap(adata)
scanpy.tl.dpt(adata)

dpt = adata.obs['dpt_pseudotime']

This takes only a few seconds. Implementing DPT ourselves seems less than ideal if scanpy already does a good job with it. We could maybe provide a convenience wrapper in scprep.run. Thoughts, @scottgigante?

scottgigante commented 4 years ago

I'm happy for us to include a wrapper to the scanpy method. Maybe something like this:

import numbers

import numpy as np
import pandas as pd

from ._lazyload import scanpy, anndata

@utils._with_pkg("scanpy")
def diffusion_pseudotime(data, root):
  # Allow the root cell to be given as an index label when data is a DataFrame
  if not isinstance(root, numbers.Integral) and isinstance(data, pd.DataFrame):
    root = np.argwhere(data.index.values == root)[0,0]
  data = anndata.AnnData(data)
  data.uns['iroot'] = root
  scanpy.pp.neighbors(data)
  scanpy.tl.diffmap(data)
  scanpy.tl.dpt(data)
  return data.obs['dpt_pseudotime']

scottgigante commented 4 years ago

I would expect the return type to be a pd.Series if the input is a dataframe.
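
One way the root lookup and the return-type rule could be factored out, independent of scanpy. This is a rough sketch, not the scprep API: the helper names `resolve_root` and `match_input_type` are hypothetical, and it assumes an integer root always means a position rather than an index label.

```python
import numbers

import numpy as np
import pandas as pd


def resolve_root(data, root):
    """Return the integer position of the root cell (hypothetical helper)."""
    if isinstance(root, numbers.Integral):
        # Integers are treated as positions, as in the wrapper above
        return int(root)
    if isinstance(data, pd.DataFrame):
        matches = np.flatnonzero(data.index.values == root)
        if len(matches) == 0:
            raise ValueError("root {!r} not found in data.index".format(root))
        return int(matches[0])
    raise TypeError("root must be an integer unless data is a DataFrame")


def match_input_type(pseudotime, data):
    """Return a pd.Series for DataFrame input, otherwise a NumPy array."""
    values = np.asarray(pseudotime)
    if isinstance(data, pd.DataFrame):
        return pd.Series(values, index=data.index, name="dpt_pseudotime")
    return values


# Usage with made-up cell names and pseudotime values
df = pd.DataFrame(np.zeros((3, 2)), index=["cell_a", "cell_b", "cell_c"])
root_position = resolve_root(df, "cell_c")
as_series = match_input_type([0.0, 0.5, 1.0], df)
as_array = match_input_type([0.0, 1.0], np.zeros((2, 2)))
```

The wrapper would then call `resolve_root` before setting `data.uns['iroot']` and pass the scanpy result through `match_input_type` on the way out.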

dburkhardt commented 4 years ago

Great. For slingshot, the code below will run the method. I'm still working out the best way to extract gene-ordering information from slingshot.

install.packages("BiocManager")
BiocManager::install("slingshot")

library(slingshot)
library(SingleCellExperiment)

# Load counts, PHATE, and cluster labels
counts <- as.data.frame(read.csv('~/scRNAseq/Treutlein.expression.csv'))
phate = as.matrix(read.csv('/home/dan/scRNAseq/Treutlein.phate.csv'))
pclusters = read.csv('/home/dan/scRNAseq/Treutlein.phate_clusters.csv')

# Create SingleCellExperiment 
# How am I actually supposed to have the column names not be passed as genes??
sim <- SingleCellExperiment(assays = List(counts = t(as.matrix(counts[,2:2001]))))

# Add dim red data and clusters to SCE
reducedDims(sim) <- SimpleList(PHATE=phate)
colData(sim)$pclusters <- pclusters[,2]

# Optional, plot data and clusters
library(RColorBrewer)
plot(phate, col = brewer.pal(5,"Set1")[sim$pclusters], pch=16, asp = 1)

# Do Slingshot
sce <- slingshot(sim, clusterLabels = 'pclusters', reducedDim = 'PHATE')

summary(sce$slingPseudotime_3)

# Plot Slingshot
colors <- colorRampPalette(brewer.pal(11,'Spectral')[-6])(100)
plotcol <- colors[cut(sce$slingPseudotime_1, breaks=100)]

# For some reason not all points are plotted here
plot(reducedDims(sce)$PHATE, col = plotcol, pch=16, asp = 1)
lines(SlingshotDataSet(sce), lwd=2, col='black')

# This plots the 'scaffold'
plot(reducedDims(sce)$PHATE, col = brewer.pal(5,"Set1")[sim$pclusters], pch=16, asp = 1)
lines(SlingshotDataSet(sce), lwd=2, type = 'lineages', col = 'black')

# You can get the orderings of the points from these variables
sce$slingPseudotime_1
sce$slingPseudotime_2