amices / mice

Multivariate Imputation by Chained Equations
https://amices.org/mice/
GNU General Public License v2.0

"literanger" as alternative backend for random forest (faster prediction) #648

Open stephematician opened 4 weeks ago

stephematician commented 4 weeks ago

Offer the "literanger" package as an alternative backend for faster prediction from random forest models.

literanger implements more or less the same algorithm as ranger, but with a refactored interface that reduces prediction overhead (fewer copies and more generous use of templates on the C++ side).

It uses a marginally different (but generally equivalent) procedure for drawing the predicted value. For each missing value, a tree is selected at random, and then an observed value is drawn at random from the leaf node of that tree that the missing value falls into.
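The drawing procedure described above can be sketched in a few lines of base R. This is a toy illustration only, not literanger's actual code: the leaf IDs, the three-tree "forest", and the helper `draw_one` are all made up for the example, and a real forest would supply the leaf memberships.

```r
# Toy sketch of "pick a tree, then pick a donor from the matching leaf".
# All values below are hypothetical; no ranger/literanger calls are used.
set.seed(42L)
n_tree <- 3L
y_obs <- c(10, 11, 12, 20, 21, 22)   # observed values of the target variable

# leaf_obs[i, k]: leaf ID that observed row i falls into in tree k
leaf_obs <- cbind(c(1, 1, 1, 2, 2, 2),
                  c(1, 1, 2, 2, 2, 2),
                  c(1, 1, 1, 1, 2, 2))
# leaf_mis[j, k]: leaf ID that missing row j falls into in tree k
leaf_mis <- cbind(c(1, 2), c(1, 2), c(1, 2))

draw_one <- function(j) {
    tree <- sample.int(n_tree, size=1L)                     # 1. random tree
    donors <- which(leaf_obs[, tree] == leaf_mis[j, tree])  # 2. observed rows in the same leaf
    y_obs[donors[sample.int(length(donors), size=1L)]]      # 3. random observed value
}

imputed <- vapply(seq_len(nrow(leaf_mis)), draw_one, numeric(1))
print(imputed)
```

Every imputed value is an observed value, so the draw behaves like a donor-based (hot-deck style) imputation within the selected leaf.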

However, not all forest types and split rules from the original ranger package are currently supported.

Here are some short examples of the difference in elapsed time on my laptop (Ryzen 4900HS, Ubuntu 22.04):

require(microbenchmark)
require(mice)
require(ranger)
require(literanger)
set.seed(1234L)

# Add MCAR to iris
prop_missing <- 0.2
data_iris <- iris
n_prod_m <- prod(dim(data_iris))
data_iris[arrayInd(sample.int(n_prod_m, size=n_prod_m * prop_missing),
                   .dim=dim(data_iris))] <- NA

microbenchmark(
    res <- mice(data_iris, m=5L, maxit=10L, method='rf', printFlag=FALSE,
                rfPackage='ranger'),
    res <- mice(data_iris, m=5L, maxit=10L, method='rf', printFlag=FALSE,
                rfPackage='literanger'),
    times=10L
)
#        expr      min       lq     mean   median       uq      max neval
#      ranger 2.690511 2.716136 2.750803 2.726515 2.773808 2.857036    10
#  literanger 1.095654 1.108470 1.125039 1.118390 1.123107 1.214325    10

# Add MCAR to a larger, multivariate-normal dataset
n_dim <- 5E1
n_obs <- 5E3
chol_sigma <- matrix(rnorm(n_dim^2), nrow=n_dim)
data_mvn <- matrix(rnorm(n_obs * n_dim), nrow=n_obs, ncol=n_dim) %*% chol_sigma
colnames(data_mvn) <- make.names(seq_len(n_dim))

n_prod_m <- prod(dim(data_mvn))
data_mvn[arrayInd(sample.int(n_prod_m, size=n_prod_m * prop_missing),
                  .dim=dim(data_mvn))] <- NA
data_mvn <- as.data.frame(data_mvn)

microbenchmark(
    res <- mice(data_mvn, m=5L, maxit=10L, method='rf', printFlag=FALSE,
                rfPackage='ranger'),
    res <- mice(data_mvn, m=5L, maxit=10L, method='rf', printFlag=FALSE,
                rfPackage='literanger'),
    times=5L
)
#        expr      min       lq     mean   median       uq      max neval
#      ranger 579.4787 579.7766 581.0646 580.3888 582.7379 582.9408     5
#  literanger 272.8672 273.7151 274.1525 274.2950 274.7915 275.0935     5
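
As a rough summary of the two benchmarks, the ratio of median times can be computed directly from the printed output. This small sketch assumes that the first row of each result table is the ranger run and the second is the literanger run (the order in which the expressions were passed to `microbenchmark`); the variable names are only for illustration.

```r
# Median elapsed times copied from the benchmark output above.
# Assumption: row 1 = rfPackage='ranger', row 2 = rfPackage='literanger'.
median_ranger     <- c(iris=2.726515, mvn=580.3888)
median_literanger <- c(iris=1.118390, mvn=274.2950)

# Ratio > 1 means literanger was faster by that factor.
speedup <- median_ranger / median_literanger
print(round(speedup, 2))
```

On these numbers the ratio is a little over 2 for both datasets.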