topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Feature Request: Enable use of Tukey's bi-weights when using method='rlm' #1295

Open dr-consulting opened 2 years ago

dr-consulting commented 2 years ago

GIVEN a user wants to use caret::train() to fit a robust linear model (in this case using MASS::rlm()) WHEN they pass method = 'MM' as a column of the tuneGrid argument

train(...,
      tuneGrid = data.frame(intercept = TRUE,
                            psi = 'psi.bisquare',
                            method = 'MM'),
      ...
)

THEN caret::train() is able to pass that setting down to MASS::rlm() SO the user can get the cross-validated results based on their desired weighting methodology
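
For reference, the setting this request would expose can be sketched directly against MASS, without caret. This is a minimal illustration only: the stackloss data set (shipped with base R) and the object names fit_m and fit_mm are assumptions for the example, not part of the original report.

```r
library(MASS)  # provides rlm() and the psi.* weight functions

set.seed(1)  # the MM start values come from lqs(), which may use random resampling

# M-estimator with Tukey's biweight (bisquare) psi function
fit_m <- rlm(stack.loss ~ ., data = stackloss, psi = psi.bisquare, method = "M")

# MM-estimator: S-estimate start, then an M-step with Tukey's biweight
fit_mm <- rlm(stack.loss ~ ., data = stackloss, method = "MM")

# Compare the coefficient estimates side by side
print(cbind(M = coef(fit_m), MM = coef(fit_mm)))
```

The two fits differ only in the method argument, which is the value the feature request asks caret::train() to forward.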


Proof of concept

I was able to do what I wanted here, but it required a little bit of hackery:

First, I grabbed the caret "recipe" for the rlm model (the list returned by getModelInfo()) and modified some of its structure to work with the rest of caret's functions:

library(caret)

# Access the existing rlm recipe (regex = FALSE requests an exact match on the model name)
rlm_info <- getModelInfo(model = 'rlm', regex = FALSE)

# Create a single-row data.frame that we then append to the existing set of parameters recognized by caret
mm_df <- data.frame(parameter = 'method', class = 'character', label = 'method')
rlm_info$rlm$parameters <- rbind(rlm_info$rlm$parameters, mm_df)

Next, I modified the rlm_info$rlm$fit function so that it could properly parse the 'MM' option when provided.

# Modify the fit function to parse method from the param object. 
# Note this exact function was copied from the original recipe with the only addition being a pathway to control
# the method argument passed to MASS::rlm() 
rlm_info$rlm$fit <- function(x, y, wts, param, lev, last, classProbs, ...) {
    dat <- if(is.data.frame(x)) x else as.data.frame(x, stringsAsFactors = TRUE)
    dat$.outcome <- y

    psi <- MASS::psi.huber # default
    if (param$psi == "psi.bisquare") {
        psi <- MASS::psi.bisquare
    }
    if (param$psi == "psi.hampel")
        psi <- MASS::psi.hampel

    # param is a single row, so a scalar if/else suffices; anything other than 'M' selects 'MM'
    estimator <- if (param$method == 'M') 'M' else 'MM'

    if(!is.null(wts))
    {
        if (param$intercept)
            out <- MASS::rlm(.outcome ~ ., data = dat, weights = wts, psi = psi, method = estimator, ...)
        else
            out <- MASS::rlm(.outcome ~ 0 + ., data = dat, weights = wts, psi = psi, method = estimator, ...)
    } else
    {
        if (param$intercept)
            out <- MASS::rlm(.outcome ~ ., data = dat, psi = psi, method = estimator, ...)
        else
            out <- MASS::rlm(.outcome ~ 0 + ., data = dat, psi = psi, method = estimator, ...)
    }
    out
}

Next, I modified rlm_info$rlm$grid so that 'MM' was available in the default tuning grid:

rlm_info$rlm$grid <- function(x, y, len=NULL, search='grid'){
    expand.grid(intercept = c(FALSE, TRUE),
                psi = c('psi.huber', 'psi.hampel', 'psi.bisquare'),
                method = c('M', 'MM'))
}
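
One possible refinement of the grid, offered as a sketch rather than a tested change to caret: according to the MASS::rlm() documentation, method = 'MM' always uses Tukey's biweight for its final M-step, so MM rows with a different psi would refit the same model. The helper name pruned_rlm_grid below is hypothetical and not part of caret.

```r
# Hypothetical helper (not part of caret): build the full grid, then drop the
# MM rows whose psi is not 'psi.bisquare', since psi has no effect for MM fits
pruned_rlm_grid <- function() {
    grid <- expand.grid(intercept = c(FALSE, TRUE),
                        psi = c('psi.huber', 'psi.hampel', 'psi.bisquare'),
                        method = c('M', 'MM'),
                        stringsAsFactors = FALSE)
    redundant <- grid$method == 'MM' & grid$psi != 'psi.bisquare'
    grid[!redundant, , drop = FALSE]
}

grid <- pruned_rlm_grid()
nrow(grid)  # 8: six M-estimator rows plus two MM rows
```

Pruning this way keeps cross-validation from spending resamples on duplicate configurations, at the cost of a non-rectangular grid.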

Then I trained the model with the desired configuration.

train_ctrl <- trainControl(method = 'repeatedcv', number = 5, repeats = 1000, savePredictions = TRUE)

result <- train(
    y ~ x, 
    data = df, 
    method = rlm_info[[1]], 
    trControl = train_ctrl, 
    tuneGrid = data.frame(intercept = TRUE,
                          psi = 'psi.bisquare',
                          method = 'MM')
)