mayer79 / missRanger

Fast multivariate imputation by random forests.
https://mayer79.github.io/missRanger/
GNU General Public License v2.0

Allow out-of-sample application #58

Closed: jeandigitale closed this issue 1 month ago

jeandigitale commented 9 months ago

Hi,

I'm excited about the new keep_forests option. I was hoping to use it to train imputation forests on a training set and then use those models to impute missing values in a test set. However, when I try, I get an error saying that the test set has missing data in other columns, so the imputations for a given variable can't be predicted. Is there any way around that?

Also, I think the documentation contains a typo: it says "Only relevant when data_only = TRUE (and when forests are grown)." I think you meant FALSE.

Thanks!

mayer79 commented 9 months ago

Good catch, thanks, also for the typo.

Having a predict() function would be very neat. As far as I know, there is no "official" way to do so. Here is a sketch:

  1. First, fill the missing values of the new data randomly with non-missing values from the original "training" data used to fit missRanger().
  2. Then refine these values by applying the stored forests' predictions iteratively, say, three times.
```r
library(missRanger)

irisWithNA <- generateNA(iris, seed = 34)

.in <- c(1:40, 51:90, 101:140)
data_train <- irisWithNA[.in, ]

imp <- missRanger(
  data_train, pmm.k = 3, num.trees = 100, data_only = FALSE, keep_forests = TRUE
)

newdata <- irisWithNA[-.in, ]

# data_train is the original unimputed dataset used to fit missRanger().
# Will add it to the "missRanger" object later to simplify the API.
predict.missRanger <- function(object, newdata, data_train, n_iter = 3, pmm.k = 5, ...) {
  to_fill <- is.na(newdata[object$visit_seq])
  to_fill_train <- is.na(data_train)

  # Initialize by randomly picking from the observed values of the original data
  for (v in object$visit_seq) {
    m <- sum(to_fill[, v])
    newdata[[v]][to_fill[, v]] <- sample(
      data_train[[v]][!to_fill_train[, v]], size = m, replace = TRUE
    )
  }

  # Refine the initial values by iteratively predicting with the stored forests
  for (i in seq_len(n_iter)) {
    for (v in object$visit_seq) {
      v_na <- to_fill[, v]
      if (!any(v_na)) next  # nothing to impute in this column
      pred <- predict(object$forests[[v]], newdata[v_na, ])$predictions
      if (pmm.k > 0) {
        # Predictive mean matching against the forest's OOB predictions
        pred <- pmm(
          xtrain = object$forests[[v]]$predictions,
          xtest = pred,
          ytrain = data_train[[v]][!is.na(data_train[[v]])],
          k = pmm.k
        )
      }
      newdata[v_na, v] <- pred
    }
  }
  newdata
}

out <- predict(imp, newdata, data_train = data_train)
head(out)
head(iris[-.in, ])

# Did not change existing values? (Should be TRUE)
all(out[!is.na(newdata)] == newdata[!is.na(newdata)])

# Any missings left? (Should be FALSE)
anyNA(out)
```
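One caveat with the sketch above: it only initializes the columns listed in `visit_seq`, which (assuming missRanger's usual behaviour) are the variables that had missing values in the training data. If the new data has missing values in a column that was complete during training, that column stays `NA`, and ranger's predict() will likely fail with the same "missing data" error reported in the question. The helper `unfillable_cols()` below is hypothetical (not part of missRanger) and is just a quick way to check for this before calling predict():

```r
# Hypothetical helper (not part of missRanger): return the names of columns
# that contain NAs in `newdata` but are not listed in `visit_seq`, i.e. the
# columns the predict() sketch would leave unfilled.
unfillable_cols <- function(newdata, visit_seq) {
  na_cols <- names(newdata)[colSums(is.na(newdata)) > 0]
  setdiff(na_cols, visit_seq)
}

# Toy example: column "b" has an NA but is not in the visit sequence
toy <- data.frame(a = c(1, NA, 3), b = c(NA, 2, 3), c = 1:3)
unfillable_cols(toy, visit_seq = "a")
# [1] "b"
```

With the objects from the sketch, `unfillable_cols(newdata, imp$visit_seq)` should return `character(0)` when every column with missing values in the new data can be filled.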
mayer79 commented 1 month ago

@jeandigitale Just wanted to let you know that initial support is planned for the next release.