svazzole closed this issue 6 years ago.
I usually use mosaic plots:
> library(caret)
> lvs <- c("normal", "abnormal")
> truth <- factor(rep(lvs, times = c(86, 258)),
+                 levels = rev(lvs))
> pred <- factor(
+   c(
+     rep(lvs, times = c(54, 32)),
+     rep(lvs, times = c(27, 231))),
+   levels = rev(lvs))
> xtab <- table(pred, truth)
> confusionMatrix(xtab)
Confusion Matrix and Statistics

          truth
pred       abnormal normal
  abnormal      231     32
  normal         27     54

               Accuracy : 0.828
                 95% CI : (0.784, 0.867)
    No Information Rate : 0.75
    P-Value [Acc > NIR] : 0.00031

                  Kappa : 0.534
 Mcnemar's Test P-Value : 0.60254

            Sensitivity : 0.895
            Specificity : 0.628
         Pos Pred Value : 0.878
         Neg Pred Value : 0.667
             Prevalence : 0.750
         Detection Rate : 0.672
   Detection Prevalence : 0.765
      Balanced Accuracy : 0.762

       'Positive' Class : abnormal
> mosaicplot(confusionMatrix(xtab)$table)
The vcd package has a better version called mosaic().
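To see where the numbers in the confusionMatrix() printout come from, here is a minimal base-R sketch that recomputes most of them directly from the 2x2 counts. The function name two_class_stats is hypothetical (it is not part of caret), and it assumes rows are predictions and columns are the reference, with both listing the classes in the same order. It leaves out the confidence interval and the two p-values.

```r
# Hypothetical helper (not caret's code): recompute confusionMatrix()-style
# statistics from a 2x2 table with rows = predictions, columns = reference.
# Assumes rows and columns list the classes in the same order.
two_class_stats <- function(tab, positive = rownames(tab)[1]) {
  neg <- setdiff(rownames(tab), positive)
  tp <- tab[positive, positive]
  fp <- tab[positive, neg]
  fn <- tab[neg, positive]
  tn <- tab[neg, neg]
  n  <- sum(tab)
  acc <- (tp + tn) / n
  # Agreement expected by chance, from the row and column margins (for Kappa)
  p_exp <- sum(rowSums(tab) * colSums(tab)) / n^2
  sens <- tp / (tp + fn)
  spec <- tn / (tn + fp)
  c(
    Accuracy            = acc,
    Kappa               = (acc - p_exp) / (1 - p_exp),
    Sensitivity         = sens,
    Specificity         = spec,
    PosPredValue        = tp / (tp + fp),
    NegPredValue        = tn / (tn + fn),
    Prevalence          = (tp + fn) / n,
    DetectionRate       = tp / n,
    DetectionPrevalence = (tp + fp) / n,
    BalancedAccuracy    = (sens + spec) / 2
  )
}

# Same data as the console session above
lvs <- c("normal", "abnormal")
truth <- factor(rep(lvs, times = c(86, 258)), levels = rev(lvs))
pred <- factor(c(rep(lvs, times = c(54, 32)),
                 rep(lvs, times = c(27, 231))),
               levels = rev(lvs))
xtab <- table(pred, truth)

# Rounded values agree with the caret printout above
round(two_class_stats(xtab), 3)
```

Kappa here is Cohen's kappa: observed accuracy corrected by the agreement expected from the margins alone.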
Another option would be alluvial plots (the alluvial package by Michał Bojanowski is handy). I tend to use them when there are more than 3 classes:
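plotCM <- function(cm){
  # Long format: one row per (Prediction, Reference) cell plus its Freq
  cmdf <- as.data.frame(cm[["table"]])
  # Correct classifications (diagonal cells) in green, errors in red
  cmdf[["color"]] <- ifelse(cmdf[[1]] == cmdf[[2]], "green", "red")
  alluvial::alluvial(cmdf[, 1:2]
                     , freq = cmdf$Freq
                     , col = cmdf[["color"]]
                     , alpha = 0.5
                     , hide = cmdf$Freq == 0  # skip empty cells
  )
}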
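The data-preparation step inside plotCM() can be checked without drawing anything. The sketch below uses a hypothetical helper, cm_flows, written in base R only; it works on a plain table (rebuilt here from the counts in the mosaic-plot example) rather than on a confusionMatrix object, and it drops empty cells instead of passing them to alluvial with hide.

```r
# Hypothetical helper: turn a confusion table (rows = predictions,
# columns = reference) into the long data frame that plotCM() feeds to
# alluvial, one row per non-empty (predicted, observed) pair.
cm_flows <- function(tab) {
  # Character columns avoid comparing factors with different level sets
  df <- as.data.frame(tab, stringsAsFactors = FALSE)
  # Diagonal cells (prediction == reference) are correct classifications
  df$color <- ifelse(df[[1]] == df[[2]], "green", "red")
  # Drop empty cells, roughly what plotCM() does with `hide = Freq == 0`
  df[df$Freq > 0, , drop = FALSE]
}

# Same counts as the two-class example earlier in the thread
xtab <- as.table(matrix(
  c(231, 27, 32, 54), nrow = 2,
  dimnames = list(pred = c("abnormal", "normal"),
                  truth = c("abnormal", "normal"))
))

flows <- cm_flows(xtab)
print(flows)
```

With the alluvial package installed, the result can then be drawn with alluvial::alluvial(flows[, 1:2], freq = flows$Freq, col = flows$color, alpha = 0.5).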
Example:
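library("mlbench")
library("caret")     # train() and confusionMatrix(); method = "ranger" also needs ranger installed
library("magrittr")  # provides %>%
data(Soybean)
sb2 <- Soybean[complete.cases(Soybean), ]
sb2$Class <- factor(sb2$Class)  # drop class levels left empty after removing incomplete rows
caret::train(Class ~ ., data = sb2, method = "ranger") %>%
  confusionMatrix() %>%
  plotCM()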
produces:
[Figure: alluvial plot of the Soybean confusion matrix; image not preserved]
Pretty cool.
To be honest, the code to produce that (and the mosaic plot) is pretty straightforward. I'd like to avoid more package dependencies, so for now I'm not going to add a new plot method.
However, probably closer to the new year, I'll be breaking out all of the performance measures into a tidy package (similar to what was done with rsample). This won't impact caret at all, but it will allow me to do more without a ton of overhead.
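Since the plotting code is short, a dependency-free option is also possible. The sketch below draws a confusion-matrix heatmap using only base R graphics; the function name plot_cm_heatmap is hypothetical, and shading each cell by its share of the reference class (column) is a design choice assumed here, not something caret does.

```r
# Hypothetical base-R-only heatmap of a confusion table.
# Assumes rows of `tab` are predictions and columns are reference classes.
# Returns the column-normalized proportions invisibly.
plot_cm_heatmap <- function(tab, main = "Confusion matrix") {
  counts <- unclass(tab)
  # Share of each reference class assigned to each prediction;
  # a column with zero total would give NaN
  props <- sweep(counts, 2, colSums(counts), "/")
  k_ref  <- ncol(props)
  k_pred <- nrow(props)
  # image() draws z[i, j] at (x[i], y[j]): reference on x, prediction on y
  image(
    x = seq_len(k_ref), y = seq_len(k_pred), z = t(props),
    col = gray.colors(20, start = 0.95, end = 0.35), zlim = c(0, 1),
    axes = FALSE, xlab = "Reference", ylab = "Prediction", main = main
  )
  axis(1, at = seq_len(k_ref), labels = colnames(props))
  axis(2, at = seq_len(k_pred), labels = rownames(props), las = 1)
  box()
  # Print the raw count at the centre of each cell
  text(
    x = rep(seq_len(k_ref), times = k_pred),
    y = rep(seq_len(k_pred), each = k_ref),
    labels = as.vector(t(counts))
  )
  invisible(props)
}

# Same counts as the two-class example earlier in the thread
xtab <- as.table(matrix(
  c(231, 27, 32, 54), nrow = 2,
  dimnames = list(pred = c("abnormal", "normal"),
                  truth = c("abnormal", "normal"))
))
plot_cm_heatmap(xtab, main = "abnormal vs normal")
```

Normalizing by column makes the diagonal read as per-class sensitivity, which stays comparable across classes of very different size; the raw counts are still printed in each cell.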
Hi, I've been searching for a way to plot the confusion matrix of a trained model (especially in the multi-class case). I wasn't able to find a good answer (in my opinion), so I decided to write my own function. If anyone is interested, here is the code:
Hope someone finds it useful. Best, Simon
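In the multi-class case, caret's confusionMatrix() reports statistics per class, treating each class in turn as "positive" against all others. The base-R sketch below (hypothetical function per_class_stats, unrelated to the function mentioned in the post) shows that one-vs-all arithmetic. The 3x3 counts are made up purely for illustration.

```r
# Hypothetical one-vs-all statistics for a multi-class confusion table.
# Assumes rows = predictions, columns = reference, same class order in both.
per_class_stats <- function(tab) {
  n        <- sum(tab)
  tp       <- diag(tab)       # correct predictions per class
  pred_tot <- rowSums(tab)    # how often each class was predicted
  ref_tot  <- colSums(tab)    # how often each class actually occurs
  sens <- tp / ref_tot
  # True negatives for class k: everything outside row k and column k
  spec <- (n - pred_tot - ref_tot + tp) / (n - ref_tot)
  data.frame(
    Sensitivity      = sens,
    Specificity      = spec,
    PosPredValue     = tp / pred_tot,
    BalancedAccuracy = (sens + spec) / 2,
    row.names        = rownames(tab)
  )
}

# Made-up 3-class counts, for illustration only
toy <- as.table(matrix(
  c(10, 2, 0,
     1, 8, 3,
     0, 1, 9),
  nrow = 3, byrow = TRUE,
  dimnames = list(Prediction = c("a", "b", "c"),
                  Reference  = c("a", "b", "c"))
))
st <- per_class_stats(toy)
print(round(st, 3))
```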