topepo / caret

caret (Classification And Regression Training): an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

In silico Negative Control Testing Results in AUCs Significantly Less than 0.5 #1348

Open · tzadat opened this issue 1 year ago

I have a question that is likely not a bug but rather a technical feature of how caret performs cross-validation. In the minimal example below, I demonstrate (using mtcars) that identical data submitted with two different labels (A, B) in a single data frame results in cross-validated model scores (the predicted probability of B) that are substantially higher for the control label (A) than for the case label (B). For this negative-control experiment, I had expected the AUC to be approximately 0.5.
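For scores that predict B, the AUC equals the probability that a randomly chosen B row scores higher than a randomly chosen A row, with ties counted as one half (the Mann-Whitney form). So "scores substantially higher for A than for B" and "AUC well below 0.5" describe the same observation, and a negative control with no signal should land near 0.5. A minimal base-R sketch; `aucRank` is a hypothetical helper (not a caret or pROC function) and the toy scores are made up for illustration.

```r
# Hypothetical helper: AUC of `score` for separating `label == case` from the other level,
# computed as P(case score > control score) + 0.5 * P(tie)
aucRank <- function(score, label, case = "B") {
  caseScores    <- score[label == case]
  controlScores <- score[label != case]
  # Compare every case score against every control score
  greater <- outer(caseScores, controlScores, ">")
  ties    <- outer(caseScores, controlScores, "==")
  mean(greater + 0.5 * ties)
}

# Made-up toy scores in which the A rows score higher than the B rows
toyLabel <- c("A", "A", "B", "B")
aucRank(c(0.6, 0.7, 0.4, 0.5), toyLabel)  # 0: every B score is below every A score
aucRank(rep(0.5, 4), toyLabel)            # 0.5: all pairs tied, i.e. no signal
```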

Is there a reason for this behavior?

This bias toward AUC < 0.5 is also observed with other approaches, such as logistic regression. It persists when the input data are randomized before training and when substantially more identical data are added to the negative-control dataset, and it also appears with leave-one-out cross-validation (LOOCV) or bootstrap resampling (boot). Adding a small amount of noise to the data does not change the outcome much either.
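The LOOCV observation can arise without any tie-breaking at all. A sketch, assuming a deliberately useless "model" (hypothetical, not anything caret fits) that ignores the predictors and scores each held-out row by the share of B labels in its training set: leaving out an A row tips the remaining 63 rows toward B, and vice versa, so every A row outscores every B row.

```r
# 32 "A" rows and 32 "B" rows, as in rbind(mtcars, mtcars)
obsLabel <- rep(c("A", "B"), each = 32)

# Leave-one-out: score row i by the proportion of "B" among the other 63 labels.
# This hypothetical model never looks at the predictors.
looScore <- vapply(seq_along(obsLabel), function(i) {
  mean(obsLabel[-i] == "B")
}, numeric(1))

unique(looScore[obsLabel == "A"])  # 32/63: holding out an A leaves B in the majority
unique(looScore[obsLabel == "B"])  # 31/63: holding out a B leaves A in the majority

# Rank-based AUC for "B" (ties count one half)
caseVsControl <- outer(looScore[obsLabel == "B"], looScore[obsLabel == "A"], "-")
aucLoo <- mean((caseVsControl > 0) + 0.5 * (caseVsControl == 0))
aucLoo  # 0
```

Because caret's fold creation samples within each class, the training-set class balance under 5-fold `repeatedcv` should stay much closer to even, so this particular effect is probably smaller there; that is an assumption to check, not a confirmed explanation.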

However, adding a small signal that distinguishes B from A correctly produces AUCs above 0.5 and CV scores with the expected directionality.

These results make me suspect that, in the absence of signal, some kind of tie-breaking mechanism is choosing the control label (A).
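Another mechanism worth ruling out before assuming a tie-breaking rule: because the data are `rbind(mtcars, mtcars)`, each held-out row has an exact twin with the opposite label, and in 5-fold CV that twin usually sits in the training folds. A learner that can memorize individual rows would then tend to predict the opposite label. The base-R sketch below uses 1-nearest-neighbour as a stand-in for the random forest (an assumption: a forest does not memorize as cleanly) and a hand-rolled stratified fold assignment in place of caret's resampling; all names are hypothetical.

```r
set.seed(12345)
# Predictors of rbind(mtcars, mtcars), scaled; rows i and i + 32 are identical twins
x    <- scale(as.matrix(rbind(mtcars, mtcars)))
obs  <- rep(c("A", "B"), each = nrow(mtcars))
n    <- length(obs)
twin <- ifelse(seq_len(n) <= n / 2, seq_len(n) + n / 2, seq_len(n) - n / 2)

# Stratified 5-fold assignment: folds are shuffled separately within A and within B
fold <- integer(n)
for (cls in unique(obs)) {
  idx <- which(obs == cls)
  fold[idx] <- sample(rep_len(1:5, length(idx)))
}

# 1-nearest-neighbour score for "B" on each held-out row
scoreB <- numeric(n)
for (k in 1:5) {
  testIdx  <- which(fold == k)
  trainIdx <- which(fold != k)
  for (i in testIdx) {
    # Euclidean distance from held-out row i to every training row
    d <- sqrt(colSums((t(x[trainIdx, , drop = FALSE]) - x[i, ])^2))
    nearest   <- trainIdx[d == min(d)]      # all training rows tied at the minimum
    scoreB[i] <- mean(obs[nearest] == "B")  # share of "B" among them
  }
}

# Fraction of rows whose opposite-label twin was in the training folds
mean(fold != fold[twin])

# Rank-based AUC for "B" (ties count one half)
diffs   <- outer(scoreB[obs == "B"], scoreB[obs == "A"], "-")
aucTwin <- mean((diffs > 0) + 0.5 * (diffs == 0))
aucTwin
```

If `aucTwin` comes out well below 0.5 while most twins are split across folds, the same anti-correlation may be present in the caret run. Keeping each twin pair in the same fold (for example, folds from `groupKFold()` on a pair identifier passed as `trainControl(index = ...)`) would be one way to test that.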

Thank you for your time.

Minimal, reproducible example:

Minimal dataset:

library(caret)
library(pROC)
library(randomForest)
library(ggplot2)   # used for plot 1 (caret also attaches it)
set.seed(12345)
# ----
# The ML input data is mtcars repeated exactly twice,
# with the first copy labelled "A" and the second "B" for CV training
mlInputData <- rbind(mtcars, mtcars)
mlInputData$varOfInterest <- factor(c(rep("A", nrow(mtcars)), rep("B", nrow(mtcars))))

Minimal, runnable code:


# 5-fold CV repeated 5 times; keep the held-out class probabilities
# so they can be plotted and passed to pROC afterwards
trainControlSettings <- trainControl(method = "repeatedcv",
                                     number          = 5,
                                     repeats         = 5,
                                     verboseIter     = FALSE,
                                     savePredictions = TRUE,
                                     classProbs      = TRUE,
                                     summaryFunction = twoClassSummary,
                                     allowParallel   = FALSE)

# Random forest with mtry fixed at floor(sqrt(number of predictors)) = floor(sqrt(11)) = 3
thisMLTrainedModel <- train(varOfInterest ~ ., data = mlInputData,
                            method     = "rf",
                            metric     = "ROC",
                            trControl  = trainControlSettings,
                            tuneGrid   = expand.grid(mtry = floor(sqrt(ncol(mlInputData) - 1))),
                            importance = TRUE,
                            trace      = FALSE)
# plot 1: held-out probability of "B", grouped by the true label
cvScoresPlots <- ggplot(thisMLTrainedModel$pred, aes(x = obs, y = B)) +
    geom_boxplot(aes(color = obs), alpha = 0.8) +
    geom_point(aes(color = obs), alpha = 0.65) +
    xlab("True Status") +
    ylab("Model Score [Predicting B Status]")
print(cvScoresPlots)

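# plot 2: ROC curve over all held-out predictions. With the default levels,
# "A" rows are controls and "B" rows are cases; direction = "<" means
# cases are expected to score higher than controls.
rocOutput <- roc(obs ~ B,
                 data        = thisMLTrainedModel$pred,
                 ci          = TRUE,
                 direction   = "<",
                 plot        = TRUE,
                 boot.n      = 100,
                 ci.alpha    = 0.6,
                 stratified  = FALSE,
                 auc.polygon = FALSE,
                 print.auc   = TRUE)
# Bootstrap confidence band for sensitivity, drawn on top of the ROC plot
sens.ci <- ci.se(rocOutput, specificities = seq(0, 1, .01))
plot(sens.ci, type = "shape", col = "lightblue")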

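Plot 1: [image not included: cross-validated model scores by true status; x axis "True Status", y axis "Model Score [Predicting B Status]"]

Plot 2: [image not included: ROC curve of the held-out predictions with printed AUC and a sensitivity confidence band]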
sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Monterey 12.0.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] randomForest_4.7-1 pROC_1.18.0        caret_6.0-90       lattice_0.20-45    ggplot2_3.3.6