topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html
1.61k stars 634 forks source link

ranger with weights in RFE #1323

Closed r2j2ritson closed 1 year ago

r2j2ritson commented 1 year ago

I am trying to use recursive feature elimination on weighted classification random forest with ranger. In my data, I want to weight observations as inverse of sample size, but for a reproducible example, let's just say Class1 weighted twice as much as Class2. The model trains successfully with and without weights argument, but when I try to run it with recursive feature elimination, the model containing weights returns an error. I think for some reason the weights argument does not get updated within the RFE algorithm, or am I missing something? I'm unsure if its a bug in rfe() or if I'm missing an argument somewhere.

Minimal, runnable code:

# Load Packages
require(dplyr)
require(ranger)
require(caret)

#Create dummy data
set.seed(1)
dat <- caret::twoClassSim()

#Args
outcomeName <- "Class"
predictorsNames <- colnames(dat)[-length(colnames(dat))]
sizes <- c(1:(length(predictorsNames)-1))
x <- dat[,c(predictorsNames)]
y <- as.factor(dat[,c(outcomeName)])
wts <- dat %>% 
  dplyr::group_by(Class) %>%
  dplyr::summarise(wts = ifelse(Class %in% c("Class1"),2,1)) %>% # weight Class1 above Class2
  dplyr::ungroup(.) %>%
  dplyr::select(wts) %>%
  unlist(.) %>%
  as.vector(.)

# Normal model 
out <- caret::train(x, y,
                    method="ranger",
                    importance="permutation",
                    metric="ROC",
                    trControl= caret::trainControl(method = "repeatedcv",
                                                   summaryFunction = caret::twoClassSummary,
                                                   classProbs = TRUE, 
                                                   number = 3,
                                                   repeats = 3,
                                                   verboseIter = TRUE,
                                                   savePredictions = TRUE,
                                                   allowParallel = TRUE),
                    tuneGrid = expand.grid(min.node.size = 1,
                                           mtry = floor(sqrt(length(x))),
                                           splitrule = "hellinger"),
                    num.trees = 500,
                    max.depth = 18,
                    oob.error = T,
                    keep.inbag = T,
                    seed=1)
print(out) #successfull

# Normal model w/ wts
out_wt <- caret::train(x, y,
                    method="ranger",
                    importance="permutation",
                    metric="ROC",
                    trControl= caret::trainControl(method = "repeatedcv",
                                                   summaryFunction = caret::twoClassSummary,
                                                   classProbs = TRUE, 
                                                   number = 3,
                                                   repeats = 3,
                                                   verboseIter = TRUE,
                                                   savePredictions = TRUE,
                                                   allowParallel = TRUE),
                    tuneGrid = expand.grid(min.node.size = 1,
                                           mtry = floor(sqrt(length(x))),
                                           splitrule = "hellinger"),
                    num.trees = 500,
                    max.depth = 18,
                    weights = wts,
                    oob.error = T,
                    keep.inbag = T,
                    seed=1)
print(out_wt) #successful

## Let's try it with RFE
# Set-up Control functions of RFE w/ ranger
n_inner_folds <- 3
n_outer_folds <- 2
rpts <- 3
n_trees <- 500
max_depth <- 18

# Control w/o weights
cntrl_ranger <- caret::rfeControl(
  functions = list(fit = function(x, y, first, last, ...) {
    out <- caret::train(x, y,
                        method="ranger",
                        importance="permutation",
                        metric="ROC",
                        trControl= caret::trainControl(method = "repeatedcv",
                                                       summaryFunction = caret::twoClassSummary,
                                                       classProbs = TRUE, 
                                                       number = n_inner_folds,
                                                       repeats = rpts,
                                                       verboseIter = TRUE,
                                                       savePredictions = TRUE,
                                                       allowParallel = TRUE),
                        tuneGrid = expand.grid(min.node.size = 1,
                                               mtry = floor(sqrt(length(x))),
                                               splitrule = "hellinger"),
                        num.trees = n_trees,
                        max.depth = max_depth,
                        oob.error = T,
                        keep.inbag = T)
    return(out)
  },
  pred = function(object, x) {
    out <- as.data.frame(predict(object$finalModel, data=x, type="response")$predictions)
    out$pred <- factor(colnames(out)[apply(out,1,which.max)], levels=sort(colnames(out)))
    rownames(out) <- rownames(x)
    return(out)
  },
  rank = function(object, x, y) {
    varimps <- as.data.frame(ranger::importance(object$finalModel))
    varimps$var <- rownames(varimps)
    colnames(varimps) <- c("importance","var")
    varimps <- varimps[order(varimps$importance, decreasing=TRUE),]
    colnames(varimps)[[1]] <- "Overall"
    return(varimps)
  },
  selectSize = caret::pickSizeBest,
  selectVar = caret::pickVars,
  summary = caret::twoClassSummary),
  rerank = T,
  method = 'cv',
  number = n_outer_folds,
  returnResamp = "final",
  verbose = T,
  saveDetails = T,
  allowParallel = T
)

# Control w/ weights
cntrl_ranger_wt <- caret::rfeControl(
  functions = list(fit = function(x, y, first, last, ...) {
    out <- caret::train(x, y,
                        method="ranger",
                        importance="permutation",
                        metric="ROC",
                        trControl= caret::trainControl(method = "repeatedcv",
                                                       summaryFunction = caret::twoClassSummary,
                                                       classProbs = TRUE, 
                                                       number = n_inner_folds,
                                                       repeats = rpts,
                                                       verboseIter = TRUE,
                                                       savePredictions = TRUE,
                                                       allowParallel = TRUE),
                        tuneGrid = expand.grid(min.node.size = 1,
                                               mtry = floor(sqrt(length(x))),
                                               splitrule = "hellinger"),
                        num.trees = n_trees,
                        max.depth = max_depth,
                        weights = wts,
                        oob.error = T,
                        keep.inbag = T)
    return(out)
  },
  pred = function(object, x) {
    out <- as.data.frame(predict(object$finalModel, data=x, type="response")$predictions)
    out$pred <- factor(colnames(out)[apply(out,1,which.max)], levels=sort(colnames(out)))
    rownames(out) <- rownames(x)
    return(out)
  },
  rank = function(object, x, y) {
    varimps <- as.data.frame(ranger::importance(object$finalModel))
    varimps$var <- rownames(varimps)
    colnames(varimps) <- c("importance","var")
    varimps <- varimps[order(varimps$importance, decreasing=TRUE),]
    colnames(varimps)[[1]] <- "Overall"
    return(varimps)
  },
  selectSize = caret::pickSizeBest,
  selectVar = caret::pickVars,
  summary = caret::twoClassSummary),
  rerank = T,
  method = 'cv',
  number = n_outer_folds,
  returnResamp = "final",
  verbose = T,
  saveDetails = T,
  allowParallel = T
)

# Execute RFE w/o weights
set.seed(1)
model_rfe <- caret::rfe(x, y,
                        metric="ROC",
                        maximize=T,
                        sizes = sizes,
                        rfeControl = cntrl_ranger,
                        verbose = T)
print(model_rfe)

model_rfe_wt <- caret::rfe(x, y,
                        metric="ROC",
                        maximize=T,
                        sizes = sizes,
                        rfeControl = cntrl_ranger_wt,
                        verbose = T)

Error in { : task 1 failed - "replacement has 100 rows, data has 50"

Session Info:

>sessionInfo()

R version 4.2.1 (2022-06-23 ucrt) Platform: x86_64-w64-mingw32/x64 (64-bit) Running under: Windows 10 x64 (build 19044)

Matrix products: default

locale: [1] LC_COLLATE=English_United States.utf8 LC_CTYPE=English_United States.utf8
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C
[5] LC_TIME=English_United States.utf8

attached base packages: [1] stats graphics grDevices utils datasets methods base

other attached packages: [1] caret_6.0-93 lattice_0.20-45 ggplot2_3.4.0 ranger_0.14.1 dplyr_1.0.10

loaded via a namespace (and not attached): [1] tidyselect_1.1.2 purrr_0.3.4 reshape2_1.4.4 listenv_0.8.0
[5] splines_4.2.1 colorspace_2.0-3 vctrs_0.5.1 generics_0.1.3
[9] stats4_4.2.1 utf8_1.2.2 survival_3.3-1 prodlim_2019.11.13
[13] rlang_1.0.6 e1071_1.7-11 ModelMetrics_1.2.2.2 pillar_1.8.1
[17] glue_1.6.2 withr_2.5.0 foreach_1.5.2 lifecycle_1.0.3
[21] plyr_1.8.7 lava_1.6.10 stringr_1.4.1 timeDate_4021.104
[25] munsell_0.5.0 gtable_0.3.1 future_1.28.0 recipes_1.0.1
[29] codetools_0.2-18 parallel_4.2.1 class_7.3-20 fansi_1.0.3
[33] Rcpp_1.0.9 scales_1.2.1 ipred_0.9-13 parallelly_1.32.1
[37] digest_0.6.29 stringi_1.7.8 grid_4.2.1 hardhat_1.2.0
[41] cli_3.4.1 tools_4.2.1 magrittr_2.0.3 proxy_0.4-27
[45] tibble_3.1.8 future.apply_1.9.1 pkgconfig_2.0.3 ellipsis_0.3.2
[49] MASS_7.3-57 Matrix_1.5-1 data.table_1.14.6 pROC_1.18.0
[53] lubridate_1.8.0 gower_1.0.0 rstudioapi_0.14 iterators_1.0.14
[57] R6_2.5.1 globals_0.16.1 rpart_4.1.16 nnet_7.3-17
[61] nlme_3.1-157 compiler_4.2.1