tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/

Classification report #308

Open 1lliter8 opened 2 years ago

1lliter8 commented 2 years ago

Feature

When users want a quick overview of both per-class and model-wide precision, recall, F1 and support, sklearn.metrics.classification_report seems like a very useful function. As far as I can see, summary.conf_mat is the closest analogue in {yardstick}, but it's missing the per-class metrics.
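For comparison, here's a minimal sketch of what summary() on a conf_mat currently gives for a multiclass problem, using the hpc_cv example data that ships with yardstick: model-wide (macro-averaged) metrics only, with no per-class rows.

library(dplyr)
library(yardstick)

data("hpc_cv", package = "yardstick")

# Overall metrics only (macro-averaged where multiclass); no per-class breakdown
hpc_cv %>%
  conf_mat(obs, pred) %>%
  summary()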

As part of a project I've built my own classification_report that uses {yardstick}, but I want to check that this is something people would actually want, and that it fits the project direction, before I put in the work to polish it up and make a pull request.

If you'd like me to go ahead, are there strong opinions on how such a function should fit into the package architecture? Or how it should be parameterised to offer metrics beyond merely replicating the Python function?

juliasilge commented 2 years ago

How are they computing those per-class metrics in the multiclass case? Treating wrong predictions as just "wrong" vs. using a multiclass implementation?

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(yardstick)
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.

preds <- 
  tibble(obs = paste("class", c(0, 1, 2, 2, 0)), 
         pred = paste("class", c(0, 0, 2, 2, 0))) %>%
  mutate(across(everything(), ~ factor(., levels = paste("class", 0:2))))

classification_report <- metric_set(precision, recall, accuracy)

# Group by the true class, then compute the multiclass metric set within each group
preds %>%
  group_by(obs) %>%
  classification_report(truth = obs, estimate = pred)
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 1
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 0': 0
#> 'class 1': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 1
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 0
#> 'class 1': 0
#> # A tibble: 9 × 4
#>   obs     .metric   .estimator .estimate
#>   <fct>   <chr>     <chr>          <dbl>
#> 1 class 0 precision macro              1
#> 2 class 1 precision macro              0
#> 3 class 2 precision macro              1
#> 4 class 0 recall    macro              1
#> 5 class 1 recall    macro              0
#> 6 class 2 recall    macro              1
#> 7 class 0 accuracy  multiclass         1
#> 8 class 1 accuracy  multiclass         0
#> 9 class 2 accuracy  multiclass         1

Created on 2022-06-08 by the reprex package (v2.0.1)

Note the warnings about no true events.

1lliter8 commented 2 years ago

Scikit-learn computes precision as true positives over all predictions of a class, and recall as true positives over all true members of a class, per class. Practically I've done this using the confusion matrix, taking the true positives over the row sums and the true positives over the column sums. Looking through the _classification.py code in scikit-learn, classification_report calls precision_recall_fscore_support on line 2328, which builds a multiclass confusion matrix on line 1563; this is a three-dimensional array of 2x2xn matrices, one per class. These are summed to get the true positives, predicted totals and true totals on line 1570.

However, this methodology doesn't produce the same results as yours, and I don't understand why.

This is taken from the classification_report documentation page:

>>> from sklearn.metrics import classification_report
>>> y_true = [0, 1, 2, 2, 2]
>>> y_pred = [0, 0, 2, 2, 1]
>>> target_names = ['class 0', 'class 1', 'class 2']
>>> print(classification_report(y_true, y_pred, target_names=target_names))
              precision    recall  f1-score   support

     class 0       0.50      1.00      0.67         1
     class 1       0.00      0.00      0.00         1
     class 2       1.00      0.67      0.80         3

    accuracy                           0.60         5
   macro avg       0.50      0.56      0.49         5
weighted avg       0.70      0.60      0.61         5
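To sanity-check my reading of the scikit-learn source, here's a minimal sketch of that per-class computation done by hand from the confusion matrix (diagonal over row sums for precision, diagonal over column sums for recall, with rows as predictions and columns as truths). It reproduces the per-class numbers above, except that the class 1 F1 comes out NaN where scikit-learn reports 0.00.

library(dplyr)

y_true <- factor(paste("class", c(0, 1, 2, 2, 2)), levels = paste("class", 0:2))
y_pred <- factor(paste("class", c(0, 0, 2, 2, 1)), levels = paste("class", 0:2))

# Rows are predictions, columns are truths (same orientation as conf_mat())
cm <- table(Prediction = y_pred, Truth = y_true)

tibble(
  class     = rownames(cm),
  precision = diag(cm) / rowSums(cm),  # TP / all predictions of this class
  recall    = diag(cm) / colSums(cm),  # TP / all true cases of this class
  support   = colSums(cm)
) %>%
  mutate(f1 = 2 * precision * recall / (precision + recall))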

Below is your methodology from above, adapted to the scikit-learn data, plus the confusion matrix and the output of my (untidied) function.

library(tidyverse)
#> Warning: package 'tidyverse' was built under R version 3.6.3
#> Warning: package 'ggplot2' was built under R version 3.6.3
#> Warning: package 'tibble' was built under R version 3.6.3
#> Warning: package 'tidyr' was built under R version 3.6.3
#> Warning: package 'readr' was built under R version 3.6.3
#> Warning: package 'purrr' was built under R version 3.6.3
#> Warning: package 'dplyr' was built under R version 3.6.3
#> Warning: package 'stringr' was built under R version 3.6.3
#> Warning: package 'forcats' was built under R version 3.6.3
library(yardstick)
#> Warning: package 'yardstick' was built under R version 3.6.3
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.
#> 
#> Attaching package: 'yardstick'
#> The following object is masked from 'package:readr':
#> 
#>     spec
library(assertthat)
#> Warning: package 'assertthat' was built under R version 3.6.3
#> 
#> Attaching package: 'assertthat'
#> The following object is masked from 'package:tibble':
#> 
#>     has_name
library(dtplyr)
#> Warning: package 'dtplyr' was built under R version 3.6.3
source('./R/create_multiclass_report.R') # Edited slightly, my C drive!

preds_skl <- 
  tibble(obs = paste("class", c(0, 1, 2, 2, 2)), 
         pred = paste("class", c(0, 0, 2, 2, 1))) %>%
  mutate(across(everything(), ~ factor(., levels = paste("class", 0:2)))) %>% 
  rownames_to_column(var = 'id') %>% # Idiosyncrasies of my current function, ignore
  mutate(probability = 1) # Idiosyncrasies of my current function, ignore

preds_skl %>% 
  conf_mat(truth = obs, estimate = pred)
#>           Truth
#> Prediction class 0 class 1 class 2
#>    class 0       1       1       0
#>    class 1       0       0       1
#>    class 2       0       0       2

classification_report <- metric_set(precision, recall, f_meas)

preds_skl %>%
  group_by(obs) %>%
  classification_report(truth = obs, estimate = pred) %>% 
  pivot_wider(names_from = .metric, values_from = .estimate)
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 1
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 0': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 1
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 0
#> 'class 1': 1
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 1': 0
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 1': 1
#> 'class 2': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 1
#> 'class 2': 0
#> Warning: While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 
#> Precision is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of true events actually occured for each problematic event level:
#> 'class 0': 0
#> Warning: While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 
#> Recall is undefined in this case, and those levels will be removed from the averaged result.
#> Note that the following number of predicted events actually occured for each problematic event level:
#> 'class 0': 0
#> 'class 1': 1
#> # A tibble: 3 x 5
#>   obs     .estimator precision recall f_meas
#>   <fct>   <chr>          <dbl>  <dbl>  <dbl>
#> 1 class 0 macro            1    1        1  
#> 2 class 1 macro            0    0      NaN  
#> 3 class 2 macro            0.5  0.667    0.8

preds_skl %>% 
  create_multiclass_report(
    id = id,
    truth = obs,
    estimate = pred,
    probability = probability
  ) %>% 
  filter(type == 'class') %>% 
  arrange(category)
#> # A tibble: 3 x 8
#>   type  category n1_t0_precision n1_t0_recall n1_t0_f_meas n1_t0_p_mean
#>   <chr> <chr>              <dbl>        <dbl>        <dbl>        <dbl>
#> 1 class class 0              0.5        1            0.667            1
#> 2 class class 1              0          0          NaN                1
#> 3 class class 2              1          0.667        0.8              1
#> # ... with 2 more variables: n1_t0_p_median <dbl>, support <dbl>

# Forgive my column names!

Created on 2022-06-09 by the reprex package (v2.0.0)
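In case it helps the comparison, the per-class numbers scikit-learn reports can also be reproduced with yardstick's binary metrics by recoding each class one-vs-all. A rough sketch (the one_vs_all helper name is just illustrative); as before, the class 1 F1 comes back NA with a warning where scikit-learn reports 0.00.

library(dplyr)
library(purrr)
library(yardstick)

preds_skl <-
  tibble(obs  = paste("class", c(0, 1, 2, 2, 2)),
         pred = paste("class", c(0, 0, 2, 2, 1))) %>%
  mutate(across(everything(), ~ factor(., levels = paste("class", 0:2))))

per_class_metrics <- metric_set(precision, recall, f_meas)

# Recode truth and estimate as "this class" vs "other", then use the binary estimator
one_vs_all <- function(data, class) {
  data %>%
    mutate(across(c(obs, pred),
                  ~ factor(ifelse(. == class, class, "other"),
                           levels = c(class, "other")))) %>%
    per_class_metrics(truth = obs, estimate = pred)
}

levels(preds_skl$obs) %>%
  set_names() %>%
  map_dfr(~ one_vs_all(preds_skl, .x), .id = "class")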