marjoleinF / pre

an R package for deriving Prediction Rule Ensembles

Elegantly handle depth 0 trees in biased fits #17

Closed: holub008 closed this 6 years ago

holub008 commented 6 years ago

Currently, pre() fails with an out-of-bounds error if any base tree has depth 0 (i.e., the root node is never split) when tree.unbiased = FALSE (i.e., rpart is used to fit the trees) and the feature set includes a factor. Here's a reproducible example:

library(pre)
library(partykit)
library(rpart)

airquality_cp <- airquality
# introduce a factor predictor to exercise the relevant code path
set.seed(42)  # make the random factor reproducible
airquality_cp$noise <- as.factor(rbinom(nrow(airquality_cp), 1, .5))

# test the unbiased tree code path, forcing depth-0 (root-only) trees by imposing an unattainably strict significance level
tree_control <- ctree_control(alpha = 1e-10)
airq.ens.unbiased <- pre(Ozone ~ ., data = na.omit(airquality_cp), 
                         tree.control = tree_control)
coef(airq.ens.unbiased$glmnet.fit)
# this works as expected: a linear model with no rules is fit

# test the biased tree code path, forcing depth-0 trees by imposing an unattainable complexity (error reduction) threshold
tree_control <- rpart.control(cp = 1.0)
airq.ens.biased <- pre(Ozone ~ ., data = na.omit(airquality_cp), 
                       tree.unbiased = FALSE,
                       tree.control = tree_control)

The last call results in:

Error in rules[[i]] : subscript out of bounds

This PR changes this behavior so that the empty list of rules is passed through and a linear model with zero rules is fit, matching the tree.unbiased = TRUE case.
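The message `Error in rules[[i]] : subscript out of bounds` is what base R raises when code indexes past the end of a list. One common way this happens with an empty list is iterating over `1:length(rules)`, which evaluates to `c(1, 0)` when the list is empty. The sketch below assumes that pattern purely for illustration; the actual code inside pre may differ, and `buggy_collect` / `safe_collect` are hypothetical helpers, not package functions.

```r
# Hypothetical stand-in for the empty rule list produced when every
# rpart tree has depth 0; not the package's actual internals.
rules <- list()

# Fragile pattern: 1:length(rules) is c(1, 0) for an empty list,
# so the loop body runs with i = 1 and indexes past the end.
buggy_collect <- function(rules) {
  out <- character(0)
  for (i in 1:length(rules)) {
    out <- c(out, rules[[i]])
  }
  out
}

# Safe pattern: seq_along(rules) is integer(0) for an empty list,
# so the loop body never runs and an empty result passes through.
safe_collect <- function(rules) {
  out <- character(0)
  for (i in seq_along(rules)) {
    out <- c(out, rules[[i]])
  }
  out
}

err <- tryCatch(buggy_collect(rules), error = function(e) conditionMessage(e))
print(err)                   # the "subscript out of bounds" message
print(safe_collect(rules))   # an empty character vector
```

Using `seq_along()` (or `seq_len(length(x))`) is the idiomatic R guard for loops that must tolerate zero-length input.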

devtools::load_all("~/pre")  # load the patched package from a local checkout
airq.ens.biased <- pre(Ozone ~ ., data = na.omit(airquality_cp), 
                       tree.unbiased = FALSE,
                       tree.control = tree_control)
coef(airq.ens.biased$glmnet.fit)

resulting in:

(Intercept) -43.119109
Solar.R       2.318196
Wind        -21.168268
Temp         32.766033
Month         .       
Day           .       
noise1        .       

Additionally, these changes improve the formatting of the error message shown when an invalid tree.control argument is supplied. From:

Error in pre(Ozone ~ ., data = na.omit(airquality_cp), tree.control = rpart.control()) : 
  Argument 'tree.control' should be a list containing named elementscriterionlogmincriterionminsplitminbucketminprobstumpnmaxlookaheadmtrymaxdepthmultiwaysplittrymaxsurrogatenumsurrogatemajoritycaseweightsapplyfunsaveinfobonferroniupdateselectfunsplitfunsvselectfunsvsplitfunteststatsplitstatsplittestpargstesttypenresampletolintersplitMIA
In addition: Warning message:
In sort(names(ctree_control())) == sort(names(tree.control)) :
  longer object length is not a multiple of shorter object length

To:

Error in pre(Ozone ~ ., data = na.omit(airquality_cp), tree.control = rpart.control()) : 
  Argument 'tree.control' should be a list containing named elements criterion, logmincriterion, minsplit, minbucket, minprob, stump, nmax, lookahead, mtry, maxdepth, multiway, splittry, maxsurrogate, numsurrogate, majority, caseweights, applyfun, saveinfo, bonferroni, update, selectfun, splitfun, svselectfun, svsplitfun, teststat, splitstat, splittest, pargs, testtype, nresample, tol, intersplit, MIA
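The old output shows two separate problems: the expected names were concatenated with no separator, and comparing two name vectors of different lengths with `==` triggered a recycling warning. A minimal sketch of one way to address both is below, assuming a set comparison with `setequal()` and a `paste(..., collapse = ", ")` message; `check_tree_control` is a hypothetical helper and `expected` is a short illustrative subset, not the full `ctree_control()` name list or the code merged in this PR.

```r
# Hypothetical validator; `expected` is an illustrative subset of names,
# not the full list returned by partykit::ctree_control().
expected <- c("criterion", "minsplit", "minbucket", "maxdepth")

check_tree_control <- function(tree.control, expected) {
  # setequal() compares the names as sets, so vectors of different
  # lengths no longer produce the recycling warning that `==` did.
  if (!is.list(tree.control) || !setequal(names(tree.control), expected)) {
    # collapse = ", " separates the expected names in the message
    stop("Argument 'tree.control' should be a list containing named elements ",
         paste(expected, collapse = ", "), call. = FALSE)
  }
  invisible(TRUE)
}

# An rpart-style control list is rejected with a readable message
msg <- tryCatch(check_tree_control(list(cp = 0.01, minsplit = 20), expected),
                error = function(e) conditionMessage(e))
print(msg)
```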