tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

multi-outcomes support for `spruce_prob_multi` shall clarify input format for multiple `pred_levels` #223

Closed: cregouby closed this issue 1 year ago

cregouby commented 1 year ago

Feature (this is a follow-up to #161)

In situations where a model has multiple outcomes, for example when the prediction is a multi-horizon forecast or a multi-output prediction in

Reprex

See the reprex in #161.

Proposed API

The API proposed by @DavisVaughan in #161 is fine for me:

library(tibble)

# spruce_numeric_multi() expected output
tibble(
  .pred_1 = c(1, 2, 1),
  .pred_2 = c(2, 3, 1)
)
#> # A tibble: 3 x 2
#>   .pred_1 .pred_2
#>     <dbl>   <dbl>
#> 1       1       2
#> 2       2       3
#> 3       1       1

# spruce_class_multi() expected output
tibble(
  .pred_class_1 = factor(c("a", "b", "a")),
  .pred_class_2 = factor(c("c", "d", "a"))
)
#> # A tibble: 3 x 2
#>   .pred_class_1 .pred_class_2
#>   <fct>         <fct>        
#> 1 a             c            
#> 2 b             d            
#> 3 a             a
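A minimal sketch of how the numeric layout above could be produced from a matrix holding one column per outcome. `spruce_numeric_multi_from_matrix()` is a hypothetical helper, not a hardhat function; it only illustrates the `.pred_{position}` naming.

```r
library(tibble)

# Hypothetical helper (not part of hardhat): turn a numeric matrix with one
# column per outcome into a tibble with columns .pred_1, .pred_2, ...
spruce_numeric_multi_from_matrix <- function(pred_matrix) {
  stopifnot(is.matrix(pred_matrix), is.numeric(pred_matrix))
  cols <- lapply(seq_len(ncol(pred_matrix)), function(i) pred_matrix[, i])
  names(cols) <- paste0(".pred_", seq_len(ncol(pred_matrix)))
  as_tibble(cols)
}

res <- spruce_numeric_multi_from_matrix(cbind(c(1, 2, 1), c(2, 3, 1)))
res
```

This reproduces the `spruce_numeric_multi()` expected output shown above.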

But the API for spruce_prob_multi() needs some clarification:

# spruce_prob_multi() expected output
prob <- tibble(
  .pred_1 = tibble(.pred_lvl1 = c(.3, .4), .pred_lvl2 = c(.7, .6)),
  .pred_2 = tibble(.pred_lvl1 = c(.1, .3), .pred_lvl2 = c(.7, .5), .pred_lvl3 = c(.2, .2))
)

prob
#> # A tibble: 2 x 2
#>   .pred_1$.pred_lvl1 $.pred_lvl2 .pred_2$.pred_lvl1 $.pred_lvl2 $.pred_lvl3
#>                <dbl>       <dbl>              <dbl>       <dbl>       <dbl>
#> 1                0.3         0.7                0.1         0.7         0.2
#> 2                0.4         0.6                0.3         0.5         0.2

prob$.pred_1
#> # A tibble: 2 x 2
#>   .pred_lvl1 .pred_lvl2
#>        <dbl>      <dbl>
#> 1        0.3        0.7
#> 2        0.4        0.6
prob$.pred_2
#> # A tibble: 2 x 3
#>   .pred_lvl1 .pred_lvl2 .pred_lvl3
#>        <dbl>      <dbl>      <dbl>
#> 1        0.1        0.7        0.2
#> 2        0.3        0.5        0.2
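Because each packed column is itself a tibble, per-outcome checks stay simple. The short sketch below rebuilds `prob` from above, counts the levels of each outcome, and checks that every row of probabilities sums to one (up to floating-point tolerance).

```r
library(tibble)

prob <- tibble(
  .pred_1 = tibble(.pred_lvl1 = c(.3, .4), .pred_lvl2 = c(.7, .6)),
  .pred_2 = tibble(.pred_lvl1 = c(.1, .3), .pred_lvl2 = c(.7, .5), .pred_lvl3 = c(.2, .2))
)

# Number of levels per outcome: the packed tibbles may differ in width.
n_levels <- vapply(prob, ncol, integer(1))
n_levels

# Each row of each packed tibble should sum to 1 (compared with all.equal
# to tolerate floating-point error such as 0.1 + 0.7 + 0.2).
all_rows_sum_to_one <- vapply(
  prob,
  function(p) isTRUE(all.equal(unname(rowSums(p)), rep(1, nrow(p)))),
  logical(1)
)
all_rows_sum_to_one
```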

This raises questions about the inputs of spruce_prob_multi().

My initial thought was to generalize the spruce_prob() inputs, but this fails for the example above, where the two outcomes have 2 and 3 levels respectively (an expected situation), because a data frame cannot hold pred_levels vectors of unequal length:

#' @param pred_levels_df,prob_matrix_df (`type = "prob"`)
#' - `pred_levels_df` should be a data frame of `pred_levels` columns for multi-outcome predictions.
#' - `prob_matrix_df` should be a data frame of `prob_matrix` data frames for multi-outcome predictions.
#'   Each `prob_matrix` data frame should be coercible to a matrix.
spruce_prob_multi <- function(pred_levels_df, prob_matrix_df) {
...
}

Since the outcome factors of a multi-outcome model are unlikely to all have the same number of levels, they cannot all reside in a single pred_levels_df input:

>   pred_levels_df <- tibble(day = letters[1:7],
+                            month = letters[2:13])
Error in `tibble()`:
! Tibble columns must have compatible sizes.
• Size 7: Existing data.
• Size 12: Column `month`.
ℹ Only values of size one are recycled.
Run `rlang::last_trace()` to see where the error occurred.
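Unlike a data frame, a plain R list has no equal-length constraint, so level vectors of different sizes fit side by side. A minimal sketch, assuming list-based inputs; `spruce_prob_multi_list_sketch()` is hypothetical and not part of hardhat. It pairs each level vector with its probability matrix and packs the results as `.pred_1`, `.pred_2`, ...

```r
library(tibble)

# Hypothetical list-based inputs: the two outcomes have 2 and 3 levels.
pred_levels_list <- list(c("a", "b"), c("a", "b", "c"))
prob_matrix_list <- list(
  matrix(c(.3, .7, .4, .6), nrow = 2, byrow = TRUE),
  matrix(c(.1, .7, .2, .3, .5, .2), nrow = 2, byrow = TRUE)
)

# Hypothetical helper (not a hardhat function).
spruce_prob_multi_list_sketch <- function(pred_levels_list, prob_matrix_list) {
  stopifnot(length(pred_levels_list) == length(prob_matrix_list))
  outcomes <- Map(function(lvls, mat) {
    # One probability column per level, named .pred_{level}.
    stopifnot(ncol(mat) == length(lvls))
    colnames(mat) <- paste0(".pred_", lvls)
    as_tibble(mat)
  }, pred_levels_list, prob_matrix_list)
  # Outer names follow the .pred_{position} convention.
  names(outcomes) <- paste0(".pred_", seq_along(outcomes))
  as_tibble(outcomes)
}

lengths(pred_levels_list)
res <- spruce_prob_multi_list_sketch(pred_levels_list, prob_matrix_list)
res
```

The outer tibble has one packed column per outcome, each with as many columns as that outcome has levels.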

What do you think the input type for pred_levels_df should be?

DavisVaughan commented 1 year ago

What if the inputs for spruce_prob_multi() were expected to be the outputs from spruce_prob()?

Like:

spruce_prob_multi(
  spruce_prob(c("a", "b"), matrix(c(.3, .7, .4, .6), nrow = 2, byrow = TRUE)),
  foo = spruce_prob(c("a", "b", "c"), matrix(c(.2, .7, .1, .2, .6, .2), nrow = 2, byrow = TRUE))
)

So spruce_prob_multi() would take `...`, and if an argument is named (like `foo` above), that name would become the suffix in `.pred_{suffix}`. If no name is given, it would use the position instead, like 1 or 2.
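A minimal sketch of this `...` interface, assuming each argument is one outcome's probability tibble with `.pred_{level}` columns (the shape spruce_prob() returns). `spruce_prob_multi_sketch()` is hypothetical, not hardhat's API; the stand-in tibbles hold the same probabilities as the matrices in the example above.

```r
library(tibble)

# Hypothetical sketch (not hardhat's API). Named arguments supply the suffix
# in .pred_{suffix}; unnamed arguments fall back to their position.
spruce_prob_multi_sketch <- function(...) {
  outcomes <- list(...)
  stopifnot(length(outcomes) > 0, all(vapply(outcomes, is.data.frame, logical(1))))
  suffixes <- names(outcomes)
  if (is.null(suffixes)) {
    suffixes <- rep("", length(outcomes))
  }
  unnamed <- suffixes == ""
  suffixes[unnamed] <- as.character(which(unnamed))
  names(outcomes) <- paste0(".pred_", suffixes)
  # Each input becomes one packed (data frame) column.
  as_tibble(outcomes)
}

# Stand-ins for spruce_prob() output.
p1 <- tibble(.pred_a = c(.3, .4), .pred_b = c(.7, .6))
p2 <- tibble(.pred_a = c(.2, .2), .pred_b = c(.7, .6), .pred_c = c(.1, .2))

res <- spruce_prob_multi_sketch(p1, foo = p2)
names(res)
```

Here the result has the columns `.pred_1` (unnamed, so positional) and `.pred_foo` (named).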


It might make sense to make all of the spruce_*_multi() functions take `...` in this way.

That way they all have a somewhat consistent interface, and we can apply the `...` suffix trick to all three.
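A rough sketch of that shared interface, assuming each `...` argument is already one outcome's prediction (a numeric vector, a factor, or a probability tibble). `bind_multi_sketch()` and the three `*_sketch()` wrappers are hypothetical; only the prefix differs between them, which is what makes the interface consistent.

```r
library(tibble)

# Hypothetical shared helper (not part of hardhat): name the `...` inputs
# {prefix}_{name or position} and bind them into one tibble.
bind_multi_sketch <- function(..., prefix) {
  outcomes <- list(...)
  suffixes <- names(outcomes)
  if (is.null(suffixes)) suffixes <- rep("", length(outcomes))
  unnamed <- suffixes == ""
  suffixes[unnamed] <- as.character(which(unnamed))
  names(outcomes) <- paste0(prefix, "_", suffixes)
  as_tibble(outcomes)
}

# Hypothetical wrappers sharing the same `...` interface.
spruce_numeric_multi_sketch <- function(...) bind_multi_sketch(..., prefix = ".pred")
spruce_class_multi_sketch <- function(...) bind_multi_sketch(..., prefix = ".pred_class")
spruce_prob_multi_sketch <- function(...) bind_multi_sketch(..., prefix = ".pred")

num <- spruce_numeric_multi_sketch(c(1, 2, 1), horizon2 = c(2, 3, 1))
cls <- spruce_class_multi_sketch(factor(c("a", "b", "a")), factor(c("c", "d", "a")))
num
cls
```

With unnamed inputs, the column names match the `.pred_1` / `.pred_class_1` layouts proposed at the top of the issue.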

github-actions[bot] commented 1 year ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.