tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/
Other
367 stars 54 forks source link

[Feature request or question] Provide option for metric analysis by class in a multi-class setting #326

Open deschen1 opened 1 year ago

deschen1 commented 1 year ago

Not sure if this is already easily available, but for a multi-class classification problem it would be good to automatically get the performance on several metrics separately for each class, not just combined. Currently, metric_set() just returns a combined (aggregated?) analysis.

(Sorry for the long reprex):

# discrim must be loaded so that parsnip can find the "klaR" engine for
# naive_Bayes(); without it, fitting the workflow fails.
library(discrim)

iris

model_recipe <- recipes::recipe(Species ~ ., data = iris)

# Create a workflow
model_final <- parsnip::naive_Bayes(Laplace = 1) |>
  parsnip::set_mode("classification") |>
  parsnip::set_engine("klaR",
                      prior = rep(1 / 3, 3),
                      usekernel = FALSE)

model_final_wf <- workflows::workflow() |>
  workflows::add_recipe(model_recipe) |>
  workflows::add_model(model_final)

train_fit <- model_final_wf |>
  generics::fit(data = iris)

# Add predictions. The probability columns are in outcome-level order
# (.pred_setosa, .pred_versicolor, .pred_virginica), so the column index of
# each row's maximum is the level code of the predicted class. max.col() with
# ties.method = "first" matches which.max(). The levels are fixed explicitly
# so they stay identical to the recoded truth even if a class is never
# predicted; otherwise yardstick would reject mismatched levels.
train_probs <- predict(train_fit, iris, type = "prob")
n_classes <- ncol(train_probs)

train_predictions <- train_probs |>
  dplyr::mutate(
    class_pred = factor(
      max.col(as.matrix(train_probs), ties.method = "first"),
      levels = seq_len(n_classes)
    )
  ) |>
  dplyr::bind_cols(iris)

# Check some metrics
multimetric <- yardstick::metric_set(yardstick::f_meas,
                                     yardstick::accuracy,
                                     yardstick::bal_accuracy,
                                     yardstick::sens,
                                     yardstick::spec,
                                     yardstick::precision,
                                     yardstick::recall,
                                     yardstick::ppv,
                                     yardstick::npv)

# Recode the truth to the same integer level codes used for class_pred
# (setosa = 1, versicolor = 2, virginica = 3). truth/estimate are tidyselect
# arguments, so they take bare column names rather than `.data$`.
train_predictions |>
  dplyr::mutate(
    Species = factor(as.integer(Species), levels = seq_len(nlevels(Species)))
  ) |>
  multimetric(truth    = Species,
              estimate = class_pred)

# A tibble: 9 × 3
  .metric      .estimator .estimate
  <chr>        <chr>          <dbl>
1 f_meas       macro           0.96
2 accuracy     multiclass      0.96
3 bal_accuracy macro           0.97
4 sens         macro           0.96
5 spec         macro           0.98
6 precision    macro           0.96
7 recall       macro           0.96
8 ppv          macro           0.96
9 npv          macro           0.98

So, in addition to this analysis, it would be great to be able to see, e.g., accuracy for each class. Is that already doable, am I missing an easy way of getting there, or could this be added? Checking per-class performance is a pretty common and standard step in a multi-class setting.

EmilHvitfeldt commented 1 year ago

Hello @deschen1 👋

You should already be able to do this with the current version of yardstick. All the metrics work on grouped data frames, so you can call

# Compute the metric set separately within each true class
multimetric(
  group_by(train_predictions, Species),
  truth = Species,
  estimate = .pred_class
)

That gives the results you want. I also took the liberty of showing augment(), which is a great function for combining your original data with predictions. For classification models it returns predicted probabilities and classes. You will get some warnings: if you stratify your calculation by the outcome, some levels will have no true events, giving undefined results.

library(tidymodels)
# discrim supplies the naive_Bayes() engines (e.g. "klaR") used below
library(discrim)

model_recipe <- recipe(Species ~ ., data = iris)

# Create a workflow: naive Bayes with a uniform prior over the three species
model_final <- naive_Bayes(Laplace = 1) |>
  set_mode("classification") |>
  set_engine("klaR", prior = rep(1/3, 3), usekernel = FALSE)

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

train_fit <- fit(model_final_wf, data = iris)

# Add predictions: augment() appends .pred_class and one .pred_<level>
# probability column per class to the original data
train_predictions <- augment(train_fit, iris)
train_predictions
#> # A tibble: 150 × 9
#>    Sepal.Len…¹ Sepal…² Petal…³ Petal…⁴ Species .pred…⁵ .pred…⁶ .pred_…⁷ .pred_…⁸
#>          <dbl>   <dbl>   <dbl>   <dbl> <fct>   <fct>     <dbl>    <dbl>    <dbl>
#>  1         5.1     3.5     1.4     0.2 setosa  setosa     1    2.98e-18 2.15e-25
#>  2         4.9     3       1.4     0.2 setosa  setosa     1    3.17e-17 6.94e-25
#>  3         4.7     3.2     1.3     0.2 setosa  setosa     1    2.37e-18 7.24e-26
#>  4         4.6     3.1     1.5     0.2 setosa  setosa     1    3.07e-17 8.69e-25
#>  5         5       3.6     1.4     0.2 setosa  setosa     1    1.02e-18 8.89e-26
#>  6         5.4     3.9     1.7     0.4 setosa  setosa     1.00 2.72e-14 4.34e-21
#>  7         4.6     3.4     1.4     0.3 setosa  setosa     1    2.32e-17 7.99e-25
#>  8         5       3.4     1.5     0.2 setosa  setosa     1    1.39e-17 8.17e-25
#>  9         4.4     2.9     1.4     0.2 setosa  setosa     1    1.99e-17 3.61e-25
#> 10         4.9     3.1     1.5     0.1 setosa  setosa     1    7.38e-18 3.62e-25
#> # … with 140 more rows, and abbreviated variable names ¹​Sepal.Length,
#> #   ²​Sepal.Width, ³​Petal.Length, ⁴​Petal.Width, ⁵​.pred_class, ⁶​.pred_setosa,
#> #   ⁷​.pred_versicolor, ⁸​.pred_virginica

# Check some metrics
multimetric <- metric_set(
  yardstick::f_meas,
  yardstick::accuracy,
  yardstick::bal_accuracy,
  yardstick::sens,
  yardstick::spec,
  yardstick::precision,
  yardstick::recall,
  yardstick::ppv,
  yardstick::npv
)

# Grouping by the true class means each group contains a single truth level,
# so metrics that need both events and non-events (e.g. bal_accuracy) come
# back NA with warnings.
train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)
#> # A tibble: 27 × 4
#>    Species    .metric      .estimator .estimate
#>    <fct>      <chr>        <chr>          <dbl>
#>  1 setosa     f_meas       macro          1    
#>  2 versicolor f_meas       macro          0.969
#>  3 virginica  f_meas       macro          0.969
#>  4 setosa     accuracy     multiclass     1    
#>  5 versicolor accuracy     multiclass     0.94 
#>  6 virginica  accuracy     multiclass     0.94 
#>  7 setosa     bal_accuracy macro         NA    
#>  8 versicolor bal_accuracy macro         NA    
#>  9 virginica  bal_accuracy macro         NA    
#> 10 setosa     sens         macro          1    
#> # … with 17 more rows

Created on 2022-10-25 with reprex v2.0.2

deschen1 commented 1 year ago

Thanks a lot @EmilHvitfeldt . One more question: how could I add the roc_auc for each class?

Simply adding yardstick::roc_auc to the multimetric set does not work, which is, I think, because it is a metric that takes the class probabilities. However, even if I do a standalone metric it doesn't work:

# roc_auc() takes the class-probability columns through `...`; it has no
# `estimate` argument. With three or more levels this returns a single
# multiclass AUC (Hand-Till by default), not one value per class.
train_predictions |>
  roc_auc(truth = Species, .pred_setosa:.pred_virginica)

Adding a group_by(Species) to it gives a lot of warnings and returns NaN for each class.

Without group_by(), it only gives me one combined area. However, when using roc_curve() + autoplot(), I get one curve (and hence one AUC) for each level.

DavisVaughan commented 1 year ago

I think this is the same as https://github.com/tidymodels/yardstick/issues/4

i.e. it sounds like you want one-vs-all output (with regard to factor levels). Like:

# truth
"a" "b" "c" "c" "a"

# for a, recode as
"y" "n" "n" "n" "y"

# for b, recode as
"n" "y" "n" "n" "n"

# for c, recode as
"n" "n" "y" "y" "n"

Then do the same thing for estimate, and compute 3 separate binary accuracy() calculations for each set of recoded values. yardstick:::one_vs_all_impl() is an internal helper that does this because we need to do those computations to compute some macro/micro estimates, but the kind of output that we'd get back doesn't natively fit into the rest of our API so I decided not to add it.

It is possible it could be a separate helper function that just wouldn't be usable in tuning and couldn't be combined in a metric set. It could maybe look something like:

# Proposed helper (not implemented): for each level of `truth`, recode truth
# and estimate as "that level vs. all others", apply `fn` (for example a
# metric_set()) to each binary problem, and stack the results with a
# `.level` column identifying the level treated as the event.
one_vs_all <- function(data, truth, estimate, fn) {
  # impl -- placeholder; the body is intentionally empty in this sketch
}

# Hypothetical call and output, shown only to illustrate the result shape
one_vs_all(data, Species, pred, metric_set(accuracy, precision))
#> # A tibble: 6 × 4
#>   .level     .metric   .estimator .estimate
#>   <chr>      <chr>     <chr>          <dbl>
#> 1 setosa     accuracy  binary           0.7
#> 2 setosa     precision binary           0.4
#> 3 versicolor accuracy  binary           0.8
#> 4 versicolor precision binary           0.3
#> 5 virginica  accuracy  binary           0.2
#> 6 virginica  precision binary           0.5

We'd have to consider how the API would look for numeric/class/class-prob metrics, because you supply estimate for some and ... for class-prob metrics.

deschen1 commented 1 year ago

Thanks Davis!

Not entirely sure if what you are saying is that I can get what I want by using these already implemented (?) functions or if this is something you are going to add to the package?

DavisVaughan commented 1 year ago

You cannot currently do it.

I am not sure if we should expose these utilities or not, but it has come up a few times so it might be worth it to expose them in a limited way.

deschen1 commented 1 year ago

Got it. I think it would generally be a great idea to offer class-specific metrics in a multi-class setting. From what I've seen in other tools and for certain analyses, it is a common but also very important task.

E.g. in a market research study where you segment/cluster customers, there might be a very important class that you want to predict very well even if that means you are missing out on some other classes. But you could only judge by seeing the class-specific metrics.

lardenoije commented 11 months ago

Hello @deschen1 👋

You should already be able to do this with the current version of yardstick. All the metrics work on grouped data frames, so you can call

# Per-true-class metrics: the metric set is evaluated within each group
multimetric(
  group_by(train_predictions, Species),
  truth = Species,
  estimate = .pred_class
)

That gives the results you want. I also took the liberty of showing augment(), which is a great function for combining your original data with predictions. For classification models it returns predicted probabilities and classes. You will get some warnings: if you stratify your calculation by the outcome, some levels will have no true events, giving undefined results.

library(tidymodels)
# discrim supplies the naive_Bayes() engines (e.g. "klaR") used below
library(discrim)

model_recipe <- recipe(Species ~ ., data = iris)

# Create a workflow: naive Bayes with a uniform prior over the three species
model_final <- naive_Bayes(Laplace = 1) |>
  set_mode("classification") |>
  set_engine("klaR", prior = rep(1/3, 3), usekernel = FALSE)

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

train_fit <- fit(model_final_wf, data = iris)

# Add predictions: augment() appends .pred_class and one .pred_<level>
# probability column per class to the original data
train_predictions <- augment(train_fit, iris)
train_predictions
#> # A tibble: 150 × 9
#>    Sepal.Len…¹ Sepal…² Petal…³ Petal…⁴ Species .pred…⁵ .pred…⁶ .pred_…⁷ .pred_…⁸
#>          <dbl>   <dbl>   <dbl>   <dbl> <fct>   <fct>     <dbl>    <dbl>    <dbl>
#>  1         5.1     3.5     1.4     0.2 setosa  setosa     1    2.98e-18 2.15e-25
#>  2         4.9     3       1.4     0.2 setosa  setosa     1    3.17e-17 6.94e-25
#>  3         4.7     3.2     1.3     0.2 setosa  setosa     1    2.37e-18 7.24e-26
#>  4         4.6     3.1     1.5     0.2 setosa  setosa     1    3.07e-17 8.69e-25
#>  5         5       3.6     1.4     0.2 setosa  setosa     1    1.02e-18 8.89e-26
#>  6         5.4     3.9     1.7     0.4 setosa  setosa     1.00 2.72e-14 4.34e-21
#>  7         4.6     3.4     1.4     0.3 setosa  setosa     1    2.32e-17 7.99e-25
#>  8         5       3.4     1.5     0.2 setosa  setosa     1    1.39e-17 8.17e-25
#>  9         4.4     2.9     1.4     0.2 setosa  setosa     1    1.99e-17 3.61e-25
#> 10         4.9     3.1     1.5     0.1 setosa  setosa     1    7.38e-18 3.62e-25
#> # … with 140 more rows, and abbreviated variable names ¹​Sepal.Length,
#> #   ²​Sepal.Width, ³​Petal.Length, ⁴​Petal.Width, ⁵​.pred_class, ⁶​.pred_setosa,
#> #   ⁷​.pred_versicolor, ⁸​.pred_virginica

# Check some metrics
multimetric <- metric_set(
  yardstick::f_meas,
  yardstick::accuracy,
  yardstick::bal_accuracy,
  yardstick::sens,
  yardstick::spec,
  yardstick::precision,
  yardstick::recall,
  yardstick::ppv,
  yardstick::npv
)

# Grouping by the true class means each group contains a single truth level,
# so metrics that need both events and non-events (e.g. bal_accuracy) come
# back NA with warnings.
train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)
#> # A tibble: 27 × 4
#>    Species    .metric      .estimator .estimate
#>    <fct>      <chr>        <chr>          <dbl>
#>  1 setosa     f_meas       macro          1    
#>  2 versicolor f_meas       macro          0.969
#>  3 virginica  f_meas       macro          0.969
#>  4 setosa     accuracy     multiclass     1    
#>  5 versicolor accuracy     multiclass     0.94 
#>  6 virginica  accuracy     multiclass     0.94 
#>  7 setosa     bal_accuracy macro         NA    
#>  8 versicolor bal_accuracy macro         NA    
#>  9 virginica  bal_accuracy macro         NA    
#> 10 setosa     sens         macro          1    
#> # … with 17 more rows

Created on 2022-10-25 with reprex v2.0.2

Hi @EmilHvitfeldt, is it possible with this solution to also get the standard error of the metrics? Basically, is it possible to get the same style of output that collect_metrics provides, but per class?

viv-analytics commented 10 months ago

Hi @EmilHvitfeldt , @DavisVaughan

Following up the above information on category-wise metrics, I've found a potential issue when using the estimator = "macro_weighted".


I've used a slightly different data set:

library(tidymodels)
library(palmerpenguins)

# modeldata (attached by tidymodels) also exports a `penguins` data set;
# prefer the palmerpenguins one. conflicted is not attached by tidymodels,
# so call it through its namespace.
conflicted::conflicts_prefer(palmerpenguins::penguins)

penguins_split <- initial_split(penguins, strata = "species")
penguins_train <- training(penguins_split)

rf_spec <- rand_forest() |>
  set_mode("classification")

# Close workflow() before piping into last_fit(); previously the closing
# parenthesis was missing, so the snippet did not parse.
results <-
  workflow(
    preprocessor = recipe(species ~ island + year, data = penguins_train),
    spec = rf_spec
  ) |>
  last_fit(penguins_split)

results

# A tibble: 86 × 8
   id               .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species .config             
   <chr>                   <dbl>           <dbl>        <dbl> <int> <fct>       <fct>   <chr>               
 1 train/test split        0.788           0.152       0.0599     3 Adelie      Adelie  Preprocessor1_Model1
 2 train/test split        0.788           0.152       0.0599     9 Adelie      Adelie  Preprocessor1_Model1
 3 train/test split        0.788           0.152       0.0599    15 Adelie      Adelie  Preprocessor1_Model1
 4 train/test split        0.788           0.152       0.0599    19 Adelie      Adelie  Preprocessor1_Model1
 5 train/test split        0.457           0.484       0.0599    31 Chinstrap   Adelie  Preprocessor1_Model1
 6 train/test split        0.457           0.484       0.0599    32 Chinstrap   Adelie  Preprocessor1_Model1
 7 train/test split        0.457           0.484       0.0599    36 Chinstrap   Adelie  Preprocessor1_Model1
 8 train/test split        0.457           0.484       0.0599    40 Chinstrap   Adelie  Preprocessor1_Model1
 9 train/test split        0.457           0.484       0.0599    43 Chinstrap   Adelie  Preprocessor1_Model1
10 train/test split        0.457           0.484       0.0599    49 Chinstrap   Adelie  Preprocessor1_Model1
# ℹ 76 more rows

Using estimator = "macro":

# Overall precision, macro-averaged across classes
precision(
  collect_predictions(results),
  truth = species,
  estimate = .pred_class,
  estimator = "macro"
)

# A tibble: 1 × 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 precision macro          0.603
# Macro precision within each predicted class (rows grouped by .pred_class)
collect_predictions(results) |>
  group_by(.pred_class) |>
  precision(truth = species, estimate = .pred_class, estimator = "macro")

# A tibble: 3 × 4
  .pred_class .metric   .estimator .estimate
  <fct>       <chr>     <chr>          <dbl>
1 Adelie      precision macro          0.636
2 Chinstrap   precision macro          0.417
3 Gentoo      precision macro          0.756

Using the estimator = "macro_weighted":

# Overall precision, macro-averaged with class-frequency weights
precision(
  collect_predictions(results),
  truth = species,
  estimate = .pred_class,
  estimator = "macro_weighted"
)
# A tibble: 1 × 3
  .metric   .estimator     .estimate
  <chr>     <chr>              <dbl>
1 precision macro_weighted     0.636
# Weighted-macro precision within each predicted class
collect_predictions(results) |>
  group_by(.pred_class) |>
  precision(
    truth = species,
    estimate = .pred_class,
    estimator = "macro_weighted"
  )

# A tibble: 3 × 4
  .pred_class .metric   .estimator     .estimate
  <fct>       <chr>     <chr>              <dbl>
1 Adelie      precision macro_weighted     0.636
2 Chinstrap   precision macro_weighted     0.417
3 Gentoo      precision macro_weighted     0.756

The class-level values for "macro_weighted" are identical to those for "macro".

At the overall level, it seems to be correct.

dchiu911 commented 3 months ago

I implemented a function to calculate the one-vs-all per-class metrics:

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
#> ✔ broom        1.0.5      ✔ recipes      1.0.10
#> ✔ dials        1.2.1      ✔ rsample      1.2.0 
#> ✔ dplyr        1.1.4      ✔ tibble       3.2.1 
#> ✔ ggplot2      3.5.0      ✔ tidyr        1.3.1 
#> ✔ infer        1.0.6      ✔ tune         1.1.2 
#> ✔ modeldata    1.3.0      ✔ workflows    1.1.4 
#> ✔ parsnip      1.2.0      ✔ workflowsets 1.0.1 
#> ✔ purrr        1.0.2      ✔ yardstick    1.3.0
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
#> • Search for functions across packages at https://www.tidymodels.org/find/
# ranger is the engine for the random forest fitted below
library(ranger)

# Compute one-vs-all (per-class) metrics for a multiclass classification.
#
# For every level of `truth`, the problem is recoded as binary. That level is
# the event (first factor level) and all other levels are collapsed into
# "class_0". `metric_set` is applied to each binary problem and the results
# are stacked.
#
# Arguments:
#   x          A data frame containing the truth and estimate columns.
#   truth      Unquoted name of the column of true classes (must be a factor).
#   estimate   Unquoted name of the column of predicted classes (factor or
#              character); it is matched to the levels of `truth` by label,
#              so its own level order does not matter.
#   metric_set A function created by yardstick::metric_set() that contains
#              class metrics only.
#
# Returns a tibble with a `class_group` column (the level treated as the
# event) followed by the columns returned by `metric_set` (`.metric`,
# `.estimator`, `.estimate`). There is one block of rows per level of
# `truth`, in level order.
#
# Warnings from yardstick about undefined metrics (e.g. a class that is never
# predicted) are deliberately not suppressed: they explain NA estimates.
ova_metrics <- function(x, truth, estimate, metric_set) {
  truth_vals <- dplyr::pull(x, {{ truth }})
  estimate_vals <- as.character(dplyr::pull(x, {{ estimate }}))

  if (!is.factor(truth_vals)) {
    stop("`truth` must be a factor column.", call. = FALSE)
  }

  class_levels <- levels(truth_vals)
  other <- "class_0"
  if (other %in% class_levels) {
    stop(
      "`truth` already has a level named \"", other,
      "\", which is reserved for the collapsed 'rest' class.",
      call. = FALSE
    )
  }
  truth_vals <- as.character(truth_vals)

  purrr::map(class_levels, function(lvl) {
    # Give both columns the same explicit levels so that yardstick sees
    # matching levels even when `lvl` never occurs in one of them. `lvl` is
    # listed first so it is the event under the default event_level =
    # "first". NA inputs stay NA.
    binarize <- function(v) {
      factor(dplyr::if_else(v == lvl, lvl, other), levels = c(lvl, other))
    }

    binary <- dplyr::tibble(
      truth_ova = binarize(truth_vals),
      estimate_ova = binarize(estimate_vals)
    )

    metric_set(binary, truth = truth_ova, estimate = estimate_ova) %>%
      dplyr::mutate(class_group = lvl, .before = 1)
  }) %>%
    purrr::list_rbind()
}

model_recipe <- recipe(Species ~ ., data = iris)

# Random forest; set_engine() is repeated to pass the ranger-specific
# `importance` argument
model_final <- rand_forest(
  mode = "classification",
  engine = "ranger",
  mtry = 3,
  min_n = 20,
  trees = 500
) %>%
  set_engine("ranger", importance = "impurity")

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

# Seed so the random forest fit (and the output below) is reproducible
set.seed(2024)
train_fit <- fit(model_final_wf, data = iris)

train_predictions <- augment(train_fit, iris)
train_predictions %>% conf_mat(Species, .pred_class)
#>             Truth
#> Prediction   setosa versicolor virginica
#>   setosa         50          0         0
#>   versicolor      0         47         1
#>   virginica       0          3        49

per_class_mset <-
  metric_set(accuracy, sensitivity, specificity, f_meas, bal_accuracy, kap)

# Standard multiclass (macro / multiclass) estimates, for comparison
train_predictions %>% 
  per_class_mset(truth = Species, estimate = .pred_class)
#> # A tibble: 6 × 3
#>   .metric      .estimator .estimate
#>   <chr>        <chr>          <dbl>
#> 1 accuracy     multiclass     0.973
#> 2 sensitivity  macro          0.973
#> 3 specificity  macro          0.987
#> 4 f_meas       macro          0.973
#> 5 bal_accuracy macro          0.98 
#> 6 kap          multiclass     0.96

# One-vs-all estimates: one binary set of metrics per class, via ova_metrics()
train_predictions %>% 
  ova_metrics(truth = Species, estimate = .pred_class, metric_set = per_class_mset)
#> # A tibble: 18 × 4
#>    class_group .metric      .estimator .estimate
#>    <chr>       <chr>        <chr>          <dbl>
#>  1 setosa      accuracy     binary         1    
#>  2 setosa      sensitivity  binary         1    
#>  3 setosa      specificity  binary         1    
#>  4 setosa      f_meas       binary         1    
#>  5 setosa      bal_accuracy binary         1    
#>  6 setosa      kap          binary         1    
#>  7 versicolor  accuracy     binary         0.973
#>  8 versicolor  sensitivity  binary         0.94 
#>  9 versicolor  specificity  binary         0.99 
#> 10 versicolor  f_meas       binary         0.959
#> 11 versicolor  bal_accuracy binary         0.965
#> 12 versicolor  kap          binary         0.939
#> 13 virginica   accuracy     binary         0.973
#> 14 virginica   sensitivity  binary         0.98 
#> 15 virginica   specificity  binary         0.97 
#> 16 virginica   f_meas       binary         0.961
#> 17 virginica   bal_accuracy binary         0.975
#> 18 virginica   kap          binary         0.941

Created on 2024-04-02 by the reprex package (v0.3.0)