rcastelo / GSVA

Gene set variation analysis

Speed up ssGSEA and reduce memory by moving .fastRndWalk to C++ #71

Closed: rpolicastro closed this issue 6 months ago

rpolicastro commented 1 year ago

Hello,

For ssGSEA on scRNA-seq data, the code appears to run the .fastRndWalk function n_cells * n_gene_sets times. I was curious whether moving this function to C++ could speed up this operation (and potentially make it more memory efficient), so I roughly reimplemented the function using Rcpp with a few minor changes.

Hijacking your vignette code to make some example data.

library("Rcpp")
library("GSVA")
library("BiocParallel")
library("Matrix")  ## for the CsparseMatrix coercion below

set.seed(100)

p <- 20000 ## number of genes
n <- 100    ## number of samples (cells)
## simulate expression values from a standard Gaussian distribution
X <- matrix(rnorm(p*n), nrow=p,
            dimnames=list(paste0("g", 1:p), paste0("s", 1:n)))

X <- as(X, "CsparseMatrix")

## sample gene set sizes
gs <- as.list(sample(10:100, size=100, replace=TRUE))
## sample gene sets
gs <- lapply(gs, function(n, p)
                   paste0("g", sample(1:p, size=n, replace=FALSE)), p)
names(gs) <- paste0("gs", 1:length(gs))

Preparing the data to run the old and new functions.

X <- GSVA:::.filterFeatures(X, "ssgsea")

geneSets <- GSVA:::.mapGeneSetsToFeatures(gs, rownames(X))

n <- ncol(X)

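## per-cell gene ranks (genes x cells after transposing); ties get the
## average rank, which mode<- then truncates to an integer
R <- t(sparseMatrixStats::colRanks(X, ties.method = "average"))
mode(R) <- "integer"

## rank weights |r|^0.25 used by the weighted random walk
Ra <- abs(R)^0.25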
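To make the ranking step concrete, here is a minimal C++ sketch of what the two preparation lines compute for a single cell. It assumes that colRanks ranks in ascending order (rank 1 = lowest value) and that mode(R) <- "integer" truncates the averaged ranks toward zero, as as.integer() does. The names truncatedAverageRanks and rankWeights are hypothetical and not part of GSVA.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Average ranks (1 = smallest value) for one cell's expression vector,
// truncated toward zero to mimic mode(R) <- "integer" (assumption).
std::vector<int> truncatedAverageRanks(const std::vector<double>& x) {
    const std::size_t n = x.size();
    std::vector<std::size_t> order(n);
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(),
                     [&](std::size_t a, std::size_t b) { return x[a] < x[b]; });
    std::vector<int> ranks(n);
    std::size_t i = 0;
    while (i < n) {
        // extend [i, j] over a block of tied values
        std::size_t j = i;
        while (j + 1 < n && x[order[j + 1]] == x[order[i]]) ++j;
        // average of the 1-based ranks i + 1 .. j + 1, then truncate
        const double avg = (static_cast<double>(i + 1) + static_cast<double>(j + 1)) / 2.0;
        for (std::size_t t = i; t <= j; ++t) ranks[order[t]] = static_cast<int>(avg);
        i = j + 1;
    }
    return ranks;
}

// Rank weights |r|^tau; tau = 0.25 as in the R snippet.
std::vector<double> rankWeights(const std::vector<int>& ranks, double tau = 0.25) {
    assert(tau >= 0.0);
    std::vector<double> w(ranks.size());
    for (std::size_t g = 0; g < ranks.size(); ++g)
        w[g] = std::pow(std::abs(static_cast<double>(ranks[g])), tau);
    return w;
}
```

For example, the values {3.0, 1.0, 3.0, 2.0} get ranks {3, 1, 3, 2}: the two tied 3.0 values share the average rank 3.5, which truncation turns into 3.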

The R implementation of .fastRndWalk.

.fastRndWalk <- function(gSetIdx, geneRanking, j, Ra) {
    n <- length(geneRanking)
    k <- length(gSetIdx)
    idxs <- sort.int(match(gSetIdx, geneRanking))

    stepCDFinGeneSet2 <- 
        sum(Ra[geneRanking[idxs], j] * (n - idxs + 1)) /
        sum((Ra[geneRanking[idxs], j]))    

    stepCDFoutGeneSet2 <- (n * (n + 1) / 2 - sum(n - idxs + 1)) / (n - k)

    walkStat <- stepCDFinGeneSet2 - stepCDFoutGeneSet2

    walkStat
}

R_fastRndWalk <- function(){
  es <- bplapply(as.list(1:n), function(j) {
    geneRanking <- order(R[, j], decreasing=TRUE)
    es_sample <- lapply(geneSets, .fastRndWalk, geneRanking, j, Ra)

    unlist(es_sample)
  }, BPPARAM=SerialParam(progressbar=TRUE))
  es <- do.call("cbind", es)
  return(es)
}
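The R .fastRndWalk never steps through all n genes. A gene-set member at 1-based position p adds its weight to the running hit CDF at every position from p to n, i.e. n - p + 1 times, and each non-member adds 1/(n - k) to the miss CDF n - p + 1 times; summing those contributions gives stepCDFinGeneSet2 and stepCDFoutGeneSet2. The C++ sketch below checks this reading by comparing the closed form with an explicit position-by-position running sum of P_hit - P_miss. Treating walkStat as that running sum is an assumption based on the R code, and the function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// positions: distinct 1-based positions of the k gene-set members in the
//            decreasing ranking (what match(gSetIdx, geneRanking) returns)
// weights:   their rank weights (Ra values), in the same order
// n:         number of genes in the ranking

// Closed form, mirroring the R .fastRndWalk with all arithmetic in double
double walkClosedForm(const std::vector<int>& positions,
                      const std::vector<double>& weights, int n) {
    const int k = static_cast<int>(positions.size());
    assert(k > 0 && k < n && weights.size() == positions.size());
    double sumWeighted = 0.0, sumWeights = 0.0, sumSteps = 0.0;
    for (int i = 0; i < k; ++i) {
        const double steps = static_cast<double>(n - positions[i] + 1);
        sumWeighted += weights[i] * steps;
        sumWeights += weights[i];
        sumSteps += steps;
    }
    const double nd = static_cast<double>(n);
    const double stepIn = sumWeighted / sumWeights;
    const double stepOut = (nd * (nd + 1.0) / 2.0 - sumSteps) / (nd - k);
    return stepIn - stepOut;
}

// Explicit running sum of (P_hit - P_miss) over all n positions
double walkExplicit(const std::vector<int>& positions,
                    const std::vector<double>& weights, int n) {
    const int k = static_cast<int>(positions.size());
    std::vector<double> hitWeight(n + 1, 0.0);  // weight at each 1-based position
    std::vector<bool> inSet(n + 1, false);
    double totalWeight = 0.0;
    for (int i = 0; i < k; ++i) {
        assert(positions[i] >= 1 && positions[i] <= n);
        hitWeight[positions[i]] = weights[i];
        inSet[positions[i]] = true;
        totalWeight += weights[i];
    }
    double pHit = 0.0, pMiss = 0.0, es = 0.0;
    for (int p = 1; p <= n; ++p) {
        if (inSet[p]) pHit += hitWeight[p] / totalWeight;
        else pMiss += 1.0 / (n - k);
        es += pHit - pMiss;
    }
    return es;
}
```

With n = 5, members at positions {1, 4} and weights {2, 1}, both functions give 4/3: stepCDFinGeneSet2 = (2 * 5 + 1 * 2) / 3 = 4 and stepCDFoutGeneSet2 = (15 - 7) / 3 = 8/3. The closed form costs O(k) per gene set instead of O(n).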

Here's the Rcpp implementation of fasterRndWalk.

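sourceCpp(code="
  #include <Rcpp.h>
  using namespace Rcpp;

  // [[Rcpp::export]]
  double fasterRndWalk(IntegerVector gSetIdx, IntegerVector geneRanking, int j, NumericMatrix Ra) {
    int n = geneRanking.size();
    int k = gSetIdx.size();
    // 0-based positions of the gene set members in the decreasing ranking
    IntegerVector idxs = match(gSetIdx, geneRanking) - 1;

    // weighted step sum and total weight of the gene set members
    double sum1 = 0;
    double sum2 = 0;
    for (int i = 0; i < k; ++i) {
      int idx = idxs[i];
      double value = Ra(geneRanking[idx] - 1, j - 1);
      sum1 += value * (n - idx);  // n - idx equals n - pos + 1 for a 1-based pos
      sum2 += value;
    }

    double stepCDFinGeneSet2 = sum1 / sum2;
    // NOTE: this line is evaluated in int arithmetic, so the division by (n - k)
    // truncates, and n - idxs + 1 on 0-based idxs adds one extra step per member;
    // both make walkStat slightly larger than in the R version. n * (n + 1) also
    // overflows int once n exceeds 46340.
    double stepCDFoutGeneSet2 = (n * (n + 1) / 2 - sum(n - idxs + 1)) / (n - k);
    double walkStat = stepCDFinGeneSet2 - stepCDFoutGeneSet2;

    return walkStat;
  }
")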

Rcpp_fasterRndWalk <- function() {
  es <- bplapply(as.list(1:n), function(j) {
    geneRanking <- order(R[, j], decreasing=TRUE)
    es_sample <- lapply(geneSets, fasterRndWalk, geneRanking, j, Ra)

    unlist(es_sample)
  }, BPPARAM=SerialParam(progressbar=TRUE))
  es <- do.call("cbind", es)
  return(es)
}
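Both wrappers call the walk function once per gene set and sample, and every call runs match(gSetIdx, geneRanking) against the full ranking of n genes, so one sample with G gene sets costs roughly G lookups over n genes. One hypothetical way to cut that, not necessarily what GSVA ended up doing, is to invert the ranking once per sample so that each gene-set lookup becomes an array index. A minimal C++ sketch, assuming gene-set members are 1-based row indices, as the Rcpp code above assumes when it matches them against order(); positionLookup and setPositions are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// geneRanking: 1-based gene indices ordered by decreasing rank in one sample
// (what order(R[, j], decreasing=TRUE) returns in the R code).
// Returns pos such that pos[g] is the 1-based position of gene g; pos[0] is unused.
std::vector<int> positionLookup(const std::vector<int>& geneRanking) {
    const int n = static_cast<int>(geneRanking.size());
    std::vector<int> pos(n + 1, 0);
    for (int i = 0; i < n; ++i) pos[geneRanking[i]] = i + 1;
    return pos;
}

// Positions of one gene set's members: O(k) per set instead of a fresh match()
std::vector<int> setPositions(const std::vector<int>& gSetIdx,
                              const std::vector<int>& pos) {
    std::vector<int> out;
    out.reserve(gSetIdx.size());
    for (int g : gSetIdx) {
        assert(g >= 1 && g < static_cast<int>(pos.size()));
        out.push_back(pos[g]);
    }
    return out;
}
```

The lookup is built once per sample in O(n), after which each gene set only touches its own k members.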

Benchmarking the two implementations.

bench::mark(
  R_fastRndWalk(),
  Rcpp_fasterRndWalk(),
  time_unit="s",
  iterations=10,
  check=FALSE)

# A tibble: 2 × 13
  expression             min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                   time            gc               
  <bch:expr>           <dbl>  <dbl>     <dbl> <bch:byt>    <dbl> <int> <dbl>      <dbl> <list> <list>                   <list>          <list>           
1 R_fastRndWalk()      14.4   14.4     0.0684    3.25GB    0.212    10    31      146.  <NULL> <Rprofmem [128,647 × 3]> <bench_tm [10]> <tibble [10 × 3]>
2 Rcpp_fasterRndWalk()  3.71   3.75    0.267    30.67MB    0.160    10     6       37.4 <NULL> <Rprofmem [22,813 × 3]>  <bench_tm [10]> <tibble [10 × 3]>

The C++ implementation is almost 4 times faster (median 3.75 s vs. 14.4 s) and allocates about 100 times less memory (30.67 MB vs. 3.25 GB).

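The results are slightly different: in the first 10 rows and columns shown below, every Rcpp value is larger than the corresponding R value by less than 1, which is consistent with the integer arithmetic in the stepCDFoutGeneSet2 line of fasterRndWalk.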

> R_fastRndWalk()[1:10, 1:10]
           [,1]       [,2]       [,3]      [,4]      [,5]       [,6]       [,7]      [,8]      [,9]       [,10]
gs1   1460.4123  786.42436  848.96567 1652.6371 1699.5721  868.81145   99.72889 1147.2828  403.7456  951.345237
gs2   1351.3376 1589.24939 1166.46396 -813.0662 1193.3584 1436.63892 -671.25817 1470.0535 1757.5663 1528.896167
gs3    989.5343 -724.26180  988.38412 1285.1941 1725.1857  755.75058  413.20320 1216.5852 1103.2568  860.702336
gs4   2270.6517  913.43945 1786.67177 1630.7358  832.9891  888.71093 1349.89766 1286.1466 1130.9348  347.736033
gs5    965.6410 1770.05765 -718.08133  631.8339 1105.8057 2099.18587  853.27931 1738.7601 -361.8615 1696.477819
gs6   1241.1693  605.35762 1390.54474  218.5366 1603.1661 1064.22024  738.78739 1321.9661 1595.9738  866.650390
gs7   1821.4801 1105.39881 1805.02746  676.4591  738.2390 1670.38658  800.48911 1655.6888 1616.6367 1087.332425
gs8  -1150.0557 3321.83607 -297.52957 2636.5804 2193.6280 1574.25666 1273.75154  693.2889  918.0266 2620.212948
gs9    966.0466   39.91441  -81.57265 -301.6164 1266.5278  751.08846  846.78107 1121.3221  348.3009   -2.232159
gs10  1701.1213 2497.79930 1937.79704  224.0872 2921.6033  -65.45541 1132.45168 1988.8786  482.1426 1195.779765

> Rcpp_fasterRndWalk()[1:10, 1:10]
           [,1]       [,2]       [,3]      [,4]      [,5]       [,6]      [,7]      [,8]      [,9]       [,10]
gs1   1461.2588  787.11592  849.46221 1653.3682 1700.3948  869.71002  100.5005 1147.6113  404.5719  951.448069
gs2   1351.7445 1590.10813 1167.39739 -812.7176 1193.5546 1436.89530 -670.6217 1470.8496 1758.2667 1529.485934
gs3    990.3381 -723.26908  989.34011 1285.4300 1725.7057  756.63913  413.2240 1217.1778 1103.9727  861.649998
gs4   2271.2896  914.25031 1787.42785 1631.2987  833.0586  889.30922 1350.6158 1286.2566 1131.0374  348.201812
gs5    966.2519 1770.93258 -718.06508  631.9691 1105.8603 2099.70185  853.5787 1739.6593 -361.5559 1696.634394
gs6   1242.0753  605.57559 1391.00445  219.5044 1603.2600 1064.83204  739.1218 1322.4162 1596.2087  867.207822
gs7   1822.1339 1105.67133 1805.62463  677.3429  738.5166 1670.98240  801.3700 1655.7198 1617.6060 1087.722861
gs8  -1149.3512 3321.87337 -297.04575 2637.5285 2194.2728 1574.87523 1273.8755  693.7584  918.8056 2620.856900
gs9    966.5787   40.34541  -80.57416 -301.3876 1267.3939  751.96272  846.8289 1121.3280  348.5236   -1.692659
gs10  1701.4705 2498.41092 1938.07949  224.7069 2922.3907  -64.56492 1132.6352 1989.6011  483.1052 1195.850440
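The small, always-positive offset between the two result blocks can be traced to the stepCDFoutGeneSet2 line of fasterRndWalk, which runs entirely in int arithmetic on 0-based positions. The C++ sketch below isolates that one term in two hypothetical helpers: one reproduces the Rcpp expression, the other computes the same quantity as the R .fastRndWalk in double with 1-based positions.

```cpp
#include <cassert>
#include <vector>

// As written in fasterRndWalk: idxs0 are 0-based positions and every
// operation is int, so the division by (n - k) truncates.
double stepOutIntArithmetic(const std::vector<int>& idxs0, int n) {
    const int k = static_cast<int>(idxs0.size());
    int s = 0;
    for (int idx : idxs0) s += n - idx + 1;  // one step too many per member
    return (n * (n + 1) / 2 - s) / (n - k);  // int / int; n * (n + 1) overflows for n > 46340
}

// Same quantity as in the R .fastRndWalk: 1-based positions, double arithmetic
double stepOutDouble(const std::vector<int>& idxs0, int n) {
    const int k = static_cast<int>(idxs0.size());
    assert(k < n);
    double s = 0.0;
    for (int idx : idxs0) s += static_cast<double>(n - idx);  // = n - pos + 1 with pos = idx + 1
    const double nd = static_cast<double>(n);
    return (nd * (nd + 1.0) / 2.0 - s) / (nd - k);
}
```

Because stepCDFinGeneSet2 is identical in both versions, the Rcpp walk statistic exceeds the R one by exactly stepOutDouble - stepOutIntArithmetic, i.e. k / (n - k) from the off-by-one plus the fractional part lost to truncation. Computing the term in double (or 64-bit integers) with 1-based positions should make the two implementations agree.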

My C++ is rusty (because of Rust) and I know very little C, so I imagine someone else could improve this further, or reimplement it in C to avoid adding another dependency. I'm not too proud to admit that I needed ChatGPT to debug a line of code for me here.

Some relevant versions.

> R.Version()$version.string
[1] "R version 4.2.1 (2022-06-23)"
> packageVersion("GSVA")
[1] ‘1.46.0’
> packageVersion("Rcpp")
[1] ‘1.0.9’
> packageVersion("BiocParallel")
[1] ‘1.32.5’

Cheers, Bob

rcastelo commented 1 year ago

Dear Bob (@rpolicastro),

Thank you very much for your suggestion. We'll certainly look at it carefully and try to incorporate it into the code base. I'll keep you posted here.

rcastelo commented 6 months ago

Dear Bob (@rpolicastro),

As the previous automated messages suggest, your suggestion to improve the performance of .fastRndWalk() has been implemented in the latest release of GSVA (1.52.x), which came out on May 1st, 2024. We implemented it in R code only, which in our benchmarks brought an improvement in performance comparable to that of the Rcpp or C counterpart, i.e., running one order of magnitude faster and consuming one order of magnitude less memory. Thanks again for bringing up this performance bottleneck, which has now been greatly reduced.

rpolicastro commented 6 months ago

Fantastic, I'm glad it worked out! From a peek at the code changes, the R fix was rather simple and elegant. I'm impressed that's all it ended up taking.

Cheers, Bob