tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

Multi-horizon forecast / multi-output predict support by `spruce_` #161

cregouby closed this issue 2 years ago

cregouby commented 3 years ago

Feature

In situations where a model's prediction is a multi-horizon forecast or a multi-output prediction, like in

Reprex

The reprex is taken from the README of the mlverse/tft package:

remotes::install_github("mlverse/tft")
library(tft)
library(rsample)
library(recipes)
library(yardstick)
set.seed(1)

data("vic_elec", package = "tsibbledata")
vic_elec <- vic_elec[1:256,] %>% 
  mutate(Location = as.factor("Victoria")) 
vic_elec_split <- initial_time_split(vic_elec, prop=3/4, lag=96)

vic_elec_train <- training(vic_elec_split)
vic_elec_test <- testing(vic_elec_split)

rec <- recipe(Demand ~ ., data = vic_elec_train) %>%
  update_role(Date, new_role="id") %>%
  update_role(Time, new_role="time") %>%
  update_role(Temperature, new_role="observed_input") %>%
  update_role(Holiday, new_role="known_input") %>%
  update_role(Location, new_role="static_input") %>%
  step_normalize(all_numeric(), -all_outcomes())

fit <- tft_fit(rec, vic_elec_train, epochs = 3, batch_size=100, total_time_steps=12, num_encoder_steps=10, verbose=TRUE)

yhat <- predict(fit, rec, vic_elec_test)

Expected result:

> yhat
# A tibble: 149 x 1
         .pred
   <dbl[,2,1]>
 1      22.5 …
 2      22.5 …
 3      22.5 …
 4      22.5 …
 5      22.5 …
 6      22.5 …
 7      22.5 …
 8      22.5 …
 9      22.5 …
10      22.5 …
# … with 139 more rows
> head(yhat$.pred)
, , 1

           [,1]     [,2]
  [1,] 22.53617 22.53241
  [2,] 22.53617 22.53241
  [3,] 22.53617 22.53241
  [4,] 22.53617 22.53241
  [5,] 22.53617 22.53241
  [6,] 22.53617 22.53241

Current result:

# deconstructing the predict.tft_fit function
processed <- tft:::batch_data(recipe=rec, df=vic_elec_test,
                          total_time_steps = fit$fit$config$total_time_steps,
                          device = fit$fit$config$device)
p <- tft:::predict_impl(fit, rec, processed)

p <- hardhat::spruce_numeric(p$to(device="cpu") %>% as.array)
## Error: `pred` must be a numeric vector, not a numeric matrix/array.
## Run `rlang::last_error()` to see where the error occurred.
rlang::last_error()
## <error/rlang_error>
##`pred` must be a numeric vector, not a numeric matrix/array.
Backtrace:
 1. hardhat::spruce_numeric(p$to(device = "cpu") %>% as.array)
 2. hardhat:::validate_not_matrix(pred)
 3. hardhat:::glubort("`pred` must be a numeric vector, not a numeric matrix/array.")
Run `rlang::last_trace()` to see the full context.
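
The error above comes from passing a 3-d array to a function that only accepts a plain numeric vector. As a minimal sketch of a workaround (base R only, with made-up stand-in values copied from the expected output rather than real model output), the array can be collapsed to a matrix and split into one numeric vector per horizon. The column names `.pred_1`, `.pred_2` are an assumption for illustration, not an existing hardhat convention:

```r
# Stand-in for the model output (hypothetical values): 6 rows, 2 horizons, 1 output.
pred <- array(
  c(rep(22.53617, 6), rep(22.53241, 6)),
  dim = c(6, 2, 1)
)

# Collapse the trailing singleton dimension explicitly, so a single-row
# prediction is not silently dropped to a plain vector.
pred_mat <- array(pred, dim = dim(pred)[1:2])

# One plain numeric vector per horizon; each one has no dim attribute and
# would pass a vector-only check.
horizons <- lapply(seq_len(ncol(pred_mat)), function(j) pred_mat[, j])
names(horizons) <- paste0(".pred_", seq_along(horizons))

# Assemble the vectors as the columns of a data frame.
wide <- as.data.frame(horizons)
str(wide)
```

Each column of `wide` is an ordinary numeric vector, so the per-horizon predictions could be validated one at a time instead of as a matrix-column.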

DavisVaughan commented 3 years ago

What about spruce_class()? Is there a multi outcome extension of that? If so, it wouldn't be stored in a matrix since each of the outcomes would be factors, and you can't have a factor matrix.

And spruce_prob()? Since that already returns a data frame with multiple columns, I'm not sure what the multi outcome extension would look like (or what the arguments would be).

Maybe the outputs would look something like this? We generally prefer data frames over matrices in tidymodels, so the multiple numeric outcomes would be multiple numeric columns, not a matrix-column.

library(tibble)

# spruce_numeric_multi()
tibble(
  .pred_1 = c(1, 2, 1),
  .pred_2 = c(2, 3, 1)
)
#> # A tibble: 3 x 2
#>   .pred_1 .pred_2
#>     <dbl>   <dbl>
#> 1       1       2
#> 2       2       3
#> 3       1       1

# spruce_class_multi()
tibble(
  .pred_class_1 = factor(c("a", "b", "a")),
  .pred_class_2 = factor(c("c", "d", "a"))
)
#> # A tibble: 3 x 2
#>   .pred_class_1 .pred_class_2
#>   <fct>         <fct>        
#> 1 a             c            
#> 2 b             d            
#> 3 a             a

# spruce_prob_multi()
prob <- tibble(
  .pred_1 = tibble(.pred_lvl1 = c(.3, .4), .pred_lvl2 = c(.7, .6)),
  .pred_2 = tibble(.pred_lvl1 = c(.1, .3), .pred_lvl2 = c(.7, .5), .pred_lvl3 = c(.2, .2)),
)

prob
#> # A tibble: 2 x 2
#>   .pred_1$.pred_lvl1 $.pred_lvl2 .pred_2$.pred_lvl1 $.pred_lvl2 $.pred_lvl3
#>                <dbl>       <dbl>              <dbl>       <dbl>       <dbl>
#> 1                0.3         0.7                0.1         0.7         0.2
#> 2                0.4         0.6                0.3         0.5         0.2

prob$.pred_1
#> # A tibble: 2 x 2
#>   .pred_lvl1 .pred_lvl2
#>        <dbl>      <dbl>
#> 1        0.3        0.7
#> 2        0.4        0.6
prob$.pred_2
#> # A tibble: 2 x 3
#>   .pred_lvl1 .pred_lvl2 .pred_lvl3
#>        <dbl>      <dbl>      <dbl>
#> 1        0.1        0.7        0.2
#> 2        0.3        0.5        0.2

Created on 2021-07-13 by the reprex package (v2.0.0)
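
The layouts above are written out by hand. As a rough sketch of how they could be built from raw model output, the two helpers below convert a numeric matrix into one column per outcome, and a list of probability matrices into one nested tibble per outcome. The names `numeric_multi_sketch()` and `prob_multi_sketch()` are hypothetical and are not hardhat functions; the numbered `.pred_<i>` naming is assumed from the example output above.

```r
library(tibble)

# Hypothetical helper (not part of hardhat): one numeric column per outcome.
numeric_multi_sketch <- function(pred) {
  stopifnot(is.matrix(pred), is.numeric(pred))
  cols <- lapply(seq_len(ncol(pred)), function(j) pred[, j])
  names(cols) <- paste0(".pred_", seq_along(cols))
  as_tibble(cols)
}

# Hypothetical helper (not part of hardhat): one nested tibble of class
# probabilities per outcome. Each matrix needs column names (the levels).
prob_multi_sketch <- function(prob_list) {
  stopifnot(is.list(prob_list))
  nested <- lapply(prob_list, function(prob) {
    stopifnot(is.matrix(prob), !is.null(colnames(prob)))
    out <- as_tibble(prob)
    names(out) <- paste0(".pred_", colnames(prob))
    out
  })
  names(nested) <- paste0(".pred_", seq_along(nested))
  # Splice the named list so each nested tibble becomes a data frame column.
  tibble(!!!nested)
}

# Same values as the hand-written examples above.
num <- numeric_multi_sketch(cbind(c(1, 2, 1), c(2, 3, 1)))

prob <- prob_multi_sketch(list(
  matrix(c(.3, .4, .7, .6), ncol = 2,
         dimnames = list(NULL, c("lvl1", "lvl2"))),
  matrix(c(.1, .3, .7, .5, .2, .2), ncol = 3,
         dimnames = list(NULL, c("lvl1", "lvl2", "lvl3")))
))
```

Here `num` has the columns `.pred_1` and `.pred_2`, and `prob$.pred_2` is a three-column tibble. Outcomes with different numbers of levels can sit side by side because each one is its own data frame column.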

cregouby commented 3 years ago

That seems like a perfect fit!

juliasilge commented 2 years ago

Thanks for the discussion @cregouby! 🙌
