topepo / caret

caret (Classification And Regression Training): an R package containing miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Adding the possibility to plot the confusion matrix #755

Closed (svazzole closed this issue 6 years ago)

svazzole commented 6 years ago

Hi, I've been searching for a way to plot the confusion matrix of a trained model (especially in the multi-class case). I couldn't find a good answer (in my opinion), so I wrote my own function. If anyone is interested, here is the code:

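plotConfusionMatrix <- function(model, norm = "none"){

  # `model` is a caret `train` object; `norm` is passed on to caret's
  # confusionMatrix() method for train objects ("none", "overall" or "average")
  cm <- caret::confusionMatrix(model, norm = norm)

  # cm$table holds predicted classes in rows and true (reference) classes in columns
  conf_matrix <- matrix(cm$table, ncol = ncol(cm$table))

  nr <- nrow(conf_matrix)

  # transpose so rows are true classes, and reverse the order of the
  # predicted classes for the y axis
  M <- t(conf_matrix)[, nr:1]
  colnames(M) <- colnames(cm$table)[nr:1]
  rownames(M) <- colnames(cm$table)

  # melt() gives one row per cell: Var1 (true), Var2 (predicted), value
  g <- ggplot2::ggplot(reshape2::melt(M), ggplot2::aes(x = Var1, y = Var2, fill = value)) +
       ggplot2::geom_raster() +
       ggplot2::scale_fill_gradient2(low = 'blue', high = 'red') +
       ggplot2::xlab("True") + ggplot2::ylab("Predicted") +
       ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1)) +
       # label each tile with its own (rounded) cell value
       ggplot2::geom_text(ggplot2::aes(label = round(value, 2)), vjust = 1)

  return(g)

}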

Hope someone finds it useful. Best, Simon
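A variant that starts from the table itself, rather than from a `train` object, can be handy for held-out test-set predictions. The sketch below assumes only ggplot2; the function name `plot_cm_table` and its `digits` argument are hypothetical, and the table is assumed to have predictions in rows and true classes in columns, as `table(pred, truth)` and caret's `confusionMatrix()$table` do.

```r
library(ggplot2)

# Hypothetical helper: heat map of a square confusion table with predicted
# classes in rows and true classes in columns.
plot_cm_table <- function(tab, digits = 2) {
  df <- as.data.frame(tab)                  # one row per cell, plus a Freq column
  names(df) <- c("Predicted", "True", "Freq")
  lvls <- levels(df$True)
  ggplot(df, aes(x = True, y = Predicted, fill = Freq)) +
    geom_raster() +
    geom_text(aes(label = round(Freq, digits))) +
    scale_fill_gradient(low = "white", high = "red") +
    # list the first class at the top so the diagonal runs top-left to bottom-right
    scale_y_discrete(limits = rev(lvls)) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
}

# Example with a plain table, no model needed
truth <- factor(c("a", "a", "b", "b", "c", "c"))
pred  <- factor(c("a", "b", "b", "b", "c", "a"), levels = levels(truth))
p <- plot_cm_table(table(pred, truth))
print(p)
```

Passing `cm$table` from `caret::confusionMatrix(pred, truth)` should work the same way, since that table also has predictions in rows.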

topepo commented 6 years ago

I usually use mosaic plots:

> lvs <- c("normal", "abnormal")
> truth <- factor(rep(lvs, times = c(86, 258)),
+                 levels = rev(lvs))
> pred <- factor(
+                c(
+                  rep(lvs, times = c(54, 32)),
+                  rep(lvs, times = c(27, 231))),
+                levels = rev(lvs))
> 
> xtab <- table(pred, truth)
> 
> confusionMatrix(xtab)
Confusion Matrix and Statistics

          truth
pred       abnormal normal
  abnormal      231     32
  normal         27     54

               Accuracy : 0.828         
                 95% CI : (0.784, 0.867)
    No Information Rate : 0.75          
    P-Value [Acc > NIR] : 0.00031       

                  Kappa : 0.534         
 Mcnemar's Test P-Value : 0.60254       

            Sensitivity : 0.895         
            Specificity : 0.628         
         Pos Pred Value : 0.878         
         Neg Pred Value : 0.667         
             Prevalence : 0.750         
         Detection Rate : 0.672         
   Detection Prevalence : 0.765         
      Balanced Accuracy : 0.762         

       'Positive' Class : abnormal      

> mosaicplot(confusionMatrix(xtab)$table)

The vcd package has a better version called mosaic.
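A sketch of the same table drawn with vcd, reusing the example data above and assuming the vcd package is installed. `shade = TRUE` is an optional extra that colours the tiles by their residuals under an independence model.

```r
library(caret)  # confusionMatrix()
library(vcd)    # mosaic()

lvs <- c("normal", "abnormal")
truth <- factor(rep(lvs, times = c(86, 258)), levels = rev(lvs))
pred <- factor(c(rep(lvs, times = c(54, 32)),
                 rep(lvs, times = c(27, 231))),
               levels = rev(lvs))
cm <- confusionMatrix(table(pred, truth))

# mosaic display of the 2x2 table: predicted class vs. true class
vcd::mosaic(cm$table, shade = TRUE)
```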

talegari commented 6 years ago

Another option would be alluvial plots (the alluvial package by Michał Bojanowski is handy). I tend to use them when there are more than 3 classes:

plotCM <- function(cm){
  cmdf <- as.data.frame(cm[["table"]])
  cmdf[["color"]] <- ifelse(cmdf[[1]] == cmdf[[2]], "green", "red")

  alluvial::alluvial(cmdf[,1:2]
                     , freq = cmdf$Freq
                     , col = cmdf[["color"]]
                     , alpha = 0.5
                     , hide  = cmdf$Freq == 0
                     )
}

Example:

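library("caret")     # train(), confusionMatrix()
library("magrittr")  # %>%
library("mlbench")   # Soybean data
data(Soybean)
sb2 <- Soybean[complete.cases(Soybean),]
sb2$Class <- factor(sb2$Class)  # drop class levels left with no rows

# method = "ranger" needs the ranger package installed
caret::train(Class ~., data = sb2, method = "ranger") %>%
 confusionMatrix() %>% 
 plotCM()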

produces:

[image: alluvial plot of the Soybean confusion matrix]
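When called on a `train` object, `confusionMatrix()` summarises the resampled predictions (as percentages by default, `norm = "overall"`). To see counts on a held-out set instead, the table can be built from `predict()` output. A sketch, assuming the rpart and alluvial packages are installed and using iris purely as an illustration; the alluvial call mirrors `plotCM` above.

```r
library(caret)     # createDataPartition(), train(), confusionMatrix()
library(alluvial)  # alluvial()

set.seed(42)
in_train <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
fit  <- train(Species ~ ., data = iris[in_train, ], method = "rpart")
pred <- predict(fit, iris[-in_train, ])

# counts on the held-out rows: Prediction in rows, Reference in columns
cm <- confusionMatrix(pred, iris$Species[-in_train])

cmdf <- as.data.frame(cm$table)  # columns: Prediction, Reference, Freq
alluvial(cmdf[, 1:2],
         freq = cmdf$Freq,
         col  = ifelse(cmdf$Prediction == cmdf$Reference, "green", "red"),
         hide = cmdf$Freq == 0)
```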

topepo commented 6 years ago

Pretty cool.

To be honest, the code to produce that (and the mosaic plot) is pretty straightforward. I'd like to avoid more package dependencies, so for now I'm not going to add a new plot method.

However, probably closer to the new year, I'll be breaking out all of the performance measures into a tidy package (similar to what was done with rsample). This won't impact caret at all, but it will allow me to do more without a ton of overhead.

topepo commented 6 years ago

However, probably closer to the new year, I'll be breaking out all of the performance measures into a tidy package

It's amazing what you can get done on a delayed train...

yardstick

I'll add a plot method soon but @talegari is welcome to make a PR in the short term.
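In current yardstick releases, `conf_mat()` builds the table and its result has an `autoplot()` method with `"heatmap"` and `"mosaic"` types. A short sketch using the package's bundled `two_class_example` data, assuming recent yardstick and ggplot2 versions are installed:

```r
library(yardstick)  # conf_mat(), two_class_example
library(ggplot2)    # autoplot() generic

data(two_class_example, package = "yardstick")

# truth and predicted are factor columns of the example data
cm <- conf_mat(two_class_example, truth = truth, estimate = predicted)
print(cm)

# the same table as a heat map and as a mosaic plot
print(autoplot(cm, type = "heatmap"))
print(autoplot(cm, type = "mosaic"))
```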