Closed: cregouby closed this issue 2 years ago.
What about spruce_class()? Is there a multi-outcome extension of that? If so, it wouldn't be stored in a matrix, since each outcome would be a factor, and you can't have a factor matrix.
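A quick base R check of the factor-matrix point: `matrix()` converts a factor to a plain character vector before building the matrix, so the levels and the factor class are dropped. A data frame, by contrast, keeps one factor column per outcome. The object names here are illustrative only.

```r
# A factor is an integer vector with a "levels" attribute and class "factor".
f <- factor(c("a", "b", "a"))

# matrix() calls as.vector() on classed input, which turns a factor into a
# plain character vector, so the levels and the factor class are lost.
m <- matrix(f, ncol = 1)

is.factor(m)    # FALSE
is.character(m) # TRUE

# A data frame can hold one factor column per outcome instead.
df <- data.frame(
  .pred_class_1 = factor(c("a", "b", "a")),
  .pred_class_2 = factor(c("c", "d", "a"))
)
sapply(df, is.factor)
```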
And spruce_prob()? Since that already returns a data frame with multiple columns (one per class level), I'm not sure what the multi-outcome extension would look like (or what its arguments would be).
Maybe the outputs would look something like this? We generally prefer data frames over matrices in tidymodels, so multiple numeric outcomes would be multiple numeric columns, not a matrix column.
library(tibble)
# spruce_numeric_multi()
tibble(
.pred_1 = c(1, 2, 1),
.pred_2 = c(2, 3, 1)
)
#> # A tibble: 3 x 2
#> .pred_1 .pred_2
#> <dbl> <dbl>
#> 1 1 2
#> 2 2 3
#> 3 1 1
# spruce_class_multi()
tibble(
.pred_class_1 = factor(c("a", "b", "a")),
.pred_class_2 = factor(c("c", "d", "a"))
)
#> # A tibble: 3 x 2
#> .pred_class_1 .pred_class_2
#> <fct> <fct>
#> 1 a c
#> 2 b d
#> 3 a a
# spruce_prob_multi()
prob <- tibble(
.pred_1 = tibble(.pred_lvl1 = c(.3, .4), .pred_lvl2 = c(.7, .6)),
.pred_2 = tibble(.pred_lvl1 = c(.1, .3), .pred_lvl2 = c(.7, .5), .pred_lvl3 = c(.2, .2)),
)
prob
#> # A tibble: 2 x 2
#> .pred_1$.pred_lvl1 $.pred_lvl2 .pred_2$.pred_lvl1 $.pred_lvl2 $.pred_lvl3
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.3 0.7 0.1 0.7 0.2
#> 2 0.4 0.6 0.3 0.5 0.2
prob$.pred_1
#> # A tibble: 2 x 2
#> .pred_lvl1 .pred_lvl2
#> <dbl> <dbl>
#> 1 0.3 0.7
#> 2 0.4 0.6
prob$.pred_2
#> # A tibble: 2 x 3
#> .pred_lvl1 .pred_lvl2 .pred_lvl3
#> <dbl> <dbl> <dbl>
#> 1 0.1 0.7 0.2
#> 2 0.3 0.5 0.2
Created on 2021-07-13 by the reprex package (v2.0.0)
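To make the proposed layouts concrete, here is a rough sketch, in R like the reprex above, of helpers that could produce them. The names spruce_numeric_multi(), spruce_class_multi() and spruce_prob_multi() come from the comments in the reprex and are hypothetical: they are not hardhat functions, and the argument shapes (one numeric matrix, one factor per outcome, one probability matrix per outcome with the class levels as column names) are assumptions made for illustration.

```r
library(tibble)

# Hypothetical helpers sketching the layouts proposed above; they are NOT
# part of hardhat, and their names and arguments are illustrative only.

# Numeric: an n x k prediction matrix becomes k numeric columns
# .pred_1, ..., .pred_k (one per outcome or forecast horizon).
spruce_numeric_multi <- function(pred) {
  stopifnot(is.matrix(pred), is.numeric(pred))
  cols <- lapply(seq_len(ncol(pred)), function(j) pred[, j])
  names(cols) <- paste0(".pred_", seq_along(cols))
  tibble(!!!cols)
}

# Class: one factor per outcome becomes one factor column per outcome,
# which sidesteps the "no factor matrix" problem.
spruce_class_multi <- function(...) {
  preds <- list(...)
  stopifnot(length(preds) > 0, all(vapply(preds, is.factor, logical(1))))
  names(preds) <- paste0(".pred_class_", seq_along(preds))
  tibble(!!!preds)
}

# Probabilities: one n x (number of levels) matrix per outcome, with the
# class levels as column names, becomes one packed data frame column
# per outcome (a named data frame argument to tibble() is kept packed).
spruce_prob_multi <- function(...) {
  probs <- list(...)
  cols <- lapply(probs, function(p) {
    stopifnot(is.matrix(p), is.numeric(p), !is.null(colnames(p)))
    colnames(p) <- paste0(".pred_", colnames(p))
    as_tibble(p)
  })
  names(cols) <- paste0(".pred_", seq_along(cols))
  tibble(!!!cols)
}

num <- spruce_numeric_multi(cbind(c(1, 2, 1), c(2, 3, 1)))
cls <- spruce_class_multi(factor(c("a", "b", "a")), factor(c("c", "d", "a")))
prob <- spruce_prob_multi(
  cbind(lvl1 = c(.3, .4), lvl2 = c(.7, .6)),
  cbind(lvl1 = c(.1, .3), lvl2 = c(.7, .5), lvl3 = c(.2, .2))
)
prob$.pred_2
```

Taking one probability matrix per outcome (rather than a single array) keeps each outcome free to have its own number of levels, as in the two- and three-level example above.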
That seems like a perfect fit!
Thanks for the discussion @cregouby! 🙌
Feature
In situations where a model's prediction is a multi-horizon forecast or a multi-output prediction, like in
hardhat::spruce_
commands.
Reprex
The reprex is the README of the mlverse/tft package:
Expected result:
Current result: