topepo / caret

caret (Classification And Regression Training) is an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

RFE LOOCV issues and classification performance metrics (prSummary and twoClassSummary) #1304

Open: mmarcato opened this issue 2 years ago

mmarcato commented 2 years ago

Hi,

I am using RFE with a logistic regression model to predict a factor with two classes. I've had a few issues:

  1. If I set method = 'LOOCV' in rfeControl, I can't pass a recipe to rfe.
  2. If I set method = 'LOOCV' in rfeControl, I can't get the held-out predictions from $pred (it is NULL).
  3. Using method = 'cv' with number equal to the number of samples should be equivalent to LOOCV, but then AUC/ROC is not calculated, while Sens/Spec and Precision/Recall are all reported without problems. The same thing happens if I pass the resampling indexes directly.
  4. I think the RFE function trains a model that treats the second class as the positive class (i.e. it predicts the probability of the second class), but prSummary and twoClassSummary take the first class as the positive class when evaluating the model. If my conclusion is right, I think this should be changed, because it is very confusing when the class the model predicts is not the same as the positive class used for the performance metrics.
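A possible explanation for point 3 (an assumption, not confirmed in this thread): with `number` equal to the sample size, every held-out fold contains a single observation, and ROC AUC is undefined unless both classes are present. One workaround is to pool the held-out class probabilities from all folds and compute a single AUC on the pooled set. The base-R sketch below does this with the rank-sum (Mann-Whitney) identity; `pooled_auc` is a hypothetical helper, the probabilities are made-up values, and `event = levels(obs)[1]` mirrors caret's convention of treating the first factor level as the event.

```r
# Hypothetical helper: ROC AUC computed once on pooled held-out predictions,
# via the rank-sum (Mann-Whitney U) identity.
# prob:  predicted probability of the `event` class
# obs:   observed classes as a factor
pooled_auc <- function(prob, obs, event = levels(obs)[1]) {
  is_event <- obs == event
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  if (n_pos == 0 || n_neg == 0) {
    # AUC is undefined when only one class is present,
    # which is always the case in a one-sample fold
    return(NA_real_)
  }
  r <- rank(prob)  # average ranks, so ties count as half
  (sum(r[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

obs  <- factor(c("Class1", "Class1", "Class2", "Class2", "Class1", "Class2"),
               levels = c("Class1", "Class2"))
prob <- c(0.9, 0.6, 0.4, 0.7, 0.8, 0.2)  # made-up P(Class1) values

# "Per-fold" AUC with one sample per fold: every value is NA
per_fold <- sapply(seq_along(obs), function(i) pooled_auc(prob[i], obs[i]))
per_fold

# AUC on the pooled held-out predictions
pooled_auc(prob, obs)  # 8/9
```

Applied to the issue's output, this might look like `with(subset(model_rfe$pred, Variables == 1), pooled_auc(Class1, obs))`, assuming `$pred` keeps a class-probability column named after the first level when `saveDetails = TRUE`.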

See my minimal reproducible example below; you can modify it to check issues 1 and 2:

```r
library(caret)
library(dplyr)

set.seed(1)
dat <- twoClassSim(100)
dat$Class  # factor with levels Class1 and Class2

# Set PR (or ROC) as the summary metrics for the logistic regression functions
lrFuncs$summary <- prSummary  # or twoClassSummary

ctrl <- rfeControl(functions = lrFuncs,        # logistic regression
                   # method = "LOOCV",         # problem: predictions are not available (model_rfe$pred is NULL) with LOOCV
                   method = "cv",              # k-fold cross-validation...
                   number = 100,               # ...with number of folds = number of samples
                   saveDetails = TRUE,
                   returnResamp = "all",
                   rerank = TRUE,
                   verbose = FALSE)

# Recursive feature elimination with logistic regression
model_rfe <- rfe(Class ~ .,                    # predict Class using all other variables
                 data = dat,
                 sizes = 1:5,                  # subsets of 1 to 5 variables
                 rfeControl = ctrl,
                 # metric = "AUC",             # with 100 folds, AUC/ROC is not calculated over the whole cross-validated set
                 metric = "Recall",            # optimise recall
                 maximize = TRUE)

print(model_rfe)
colnames(model_rfe$pred)
# held-out predictions for the optimal model, which uses 1 variable
model_rfe_pred <- model_rfe$pred %>% filter(Variables == 1) %>% select(pred, obs)
# with Class2 as the positive class, the confusion matrix does not match the results shown in model_rfe
confusionMatrix(data = model_rfe_pred$pred, reference = model_rfe_pred$obs, positive = "Class2", mode = "prec_recall")
# however, it does match if the positive class is set to "Class1"
confusionMatrix(data = model_rfe_pred$pred, reference = model_rfe_pred$obs, positive = "Class1", mode = "prec_recall")
# yet the fitted model predicts the probability of Class2
print(model_rfe$fit)

# checking the conclusion above
prob <- predict(model_rfe$fit,                 # final fitted glm
                newdata = dat %>% select(-Class),
                type = "response")
model_pred <- data.frame(prob = prob, obs = dat$Class)
# if probability > 0.5: pred1 assigns Class1, pred2 assigns Class2
model_pred <- model_pred %>% mutate(pred1 = factor(if_else(prob > 0.5, "Class1", "Class2")))
model_pred <- model_pred %>% mutate(pred2 = factor(if_else(prob > 0.5, "Class2", "Class1")))

confusionMatrix(data = model_pred$pred1, reference = model_pred$obs, positive = "Class1", mode = "prec_recall")
confusionMatrix(data = model_pred$pred2, reference = model_pred$obs, positive = "Class1", mode = "prec_recall")
```
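Some context for point 4: for a two-level factor response, base R's `glm(..., family = binomial)` treats the first level as "failure" and models the probability of the second level, so `predict(type = "response")` on the final fit returns P(Class2), whereas caret's `twoClassSummary` and `prSummary` treat the first level as the positive (event) class. That is consistent with what point 4 describes, and it would make `pred2` the mapping that matches the model. The sketch below checks the glm side of this on simulated data (not the issue's `twoClassSim` data); the variable names are illustrative only.

```r
set.seed(1)
n <- 200
x <- rnorm(n)
# Simulated outcome: Class2 becomes more likely as x grows
y <- factor(ifelse(runif(n) < plogis(1.5 * x), "Class2", "Class1"),
            levels = c("Class1", "Class2"))

fit  <- glm(y ~ x, family = binomial)
prob <- predict(fit, type = "response")  # probability of the SECOND level

# With an intercept, the mean fitted probability equals the observed share
# of the modelled level, which identifies that level as Class2
c(mean_prob = mean(prob), share_class2 = mean(y == "Class2"))

# Map probabilities to labels the way glm intends (same as pred2 in the issue)
lvl  <- levels(y)
pred <- factor(ifelse(prob > 0.5, lvl[2], lvl[1]), levels = lvl)

# Recall with the FIRST level as positive, as caret's summary functions do:
# TP / (TP + FN) for Class1; rows are predictions, columns are observations
tab <- table(pred = pred, obs = y)
recall_class1 <- tab["Class1", "Class1"] / sum(tab[, "Class1"])
recall_class1
```

The same recall should appear in `caret::confusionMatrix(pred, y, positive = "Class1", mode = "prec_recall")`.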

My apologies in advance if I've missed anything; I've been battling with this issue for a week now, trying to understand what's happening. Looking forward to hearing from you.

mmarcato commented 2 years ago

This answer might be related to issue 3 reported above.