tidymodels / orbital

Turn Tidymodels Workflows Into Series of Equations
https://orbital.tidymodels.org
Other
18 stars 1 forks source link

Using {bundle} package with xgboost #61

Closed jrosell closed 1 week ago

jrosell commented 1 week ago

I'm having trouble when using {xgboost} with the {bundle} package in {orbital}. The strange thing is that other models work with {bundle} but not {xgboost}. TBH, I'm not sure if this bug report should be here at {orbital} or {tidypredict} or {bundle}.

Here is an example:

tictoc::tic()
library(tidyverse)
library(tidymodels)
library(bonsai)
library(rules)
#> 
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#> 
#>     max_rules
library(stacks)
library(plotly)
#> 
#> Attaching package: 'plotly'
#> The following object is masked from 'package:ggplot2':
#> 
#>     last_plot
#> The following object is masked from 'package:stats':
#> 
#>     filter
#> The following object is masked from 'package:graphics':
#> 
#>     layout
library(orbital)
library(tidypredict)
library(future)

GRID_SIZE <- 10
set.seed(1234)
data(penguins, package = "modeldata")
penguins_split <- initial_split(drop_na(penguins, body_mass_g))
penguins_train <- training(penguins_split)
penguins_test <- testing(penguins_split)
penguins_fold <- vfold_cv(penguins_train, v  = 3)
plan(multisession)

rec_spec <-
  recipe(body_mass_g ~ ., data = penguins_train) |>
  step_unknown(all_nominal_predictors()) |>
  step_impute_median(all_numeric_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_nzv(all_predictors()) |>
  step_scale(all_numeric_predictors()) |>
  step_center(all_numeric_predictors()) |>
  step_corr(all_predictors(), threshold = 0.5)

glm_spec <-  linear_reg() |>
  set_mode("regression") |>
  set_engine("glm")
glm_res <- workflow(rec_spec, glm_spec) |> 
  tune_grid(penguins_fold, control = control_stack_grid(), grid = GRID_SIZE)
#> Warning: No tuning parameters have been detected, performance will be evaluated
#> using the resamples with no tuning. Did you want to [tune()] parameters?

xgb_spec <- boost_tree(mode = "regression", engine = "xgboost") |> 
  set_args(
    trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
    loss_reduction = tune(), sample_size = tune(), mtry = tune(), counts = FALSE
  )
xgb_res <- workflow(rec_spec, xgb_spec) |> 
  tune_grid(
    penguins_fold,
    param_info = extract_parameter_set_dials(xgb_spec) |>
      update(mtry = mtry_prop()),
    control = control_stack_grid(),
    grid = GRID_SIZE
  )
#> ! Fold1: internal:
#>   A correlation computation is required, but `estimate` is constant and ...
#>   standard deviation, resulting in a divide by 0 error. `NA` will be ret...
#> ! Fold2: internal:
#>   A correlation computation is required, but `estimate` is constant and ...
#>   standard deviation, resulting in a divide by 0 error. `NA` will be ret...
#> ! Fold3: internal:
#>   A correlation computation is required, but `estimate` is constant and ...
#>   standard deviation, resulting in a divide by 0 error. `NA` will be ret...

tree_spec <- decision_tree(tree_depth = tune(), min_n = tune()) |>
  set_mode("regression") |>
  set_engine("partykit")
tree_res <- workflow(rec_spec, tree_spec) |> 
  tune_grid(penguins_fold, control = control_stack_grid(), grid = GRID_SIZE)

ranger_spec <- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) |>
  set_mode("regression") |>
  set_engine("ranger")
ranger_res <- workflow(rec_spec, ranger_spec) |> 
  tune_grid(penguins_fold, control = control_stack_grid(), grid = GRID_SIZE)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ! Fold1: preprocessor 1/1, model 9/10:
#>   ! 5 columns were requested but there were 4 predictors in the data.
#>   ℹ 4 predictors will be used.
#> ! Fold1: preprocessor 1/1, model 10/10:
#>   ! 5 columns were requested but there were 4 predictors in the data.
#>   ℹ 4 predictors will be used.
#> ! Fold3: preprocessor 1/1, model 7/10:
#>   ! 4 columns were requested but there were 3 predictors in the data.
#>   ℹ 3 predictors will be used.
#> ! Fold3: preprocessor 1/1, model 8/10:
#>   ! 4 columns were requested but there were 3 predictors in the data.
#>   ℹ 3 predictors will be used.
#> ! Fold3: preprocessor 1/1, model 9/10:
#>   ! 5 columns were requested but there were 3 predictors in the data.
#>   ℹ 3 predictors will be used.
#> ! Fold3: preprocessor 1/1, model 10/10:
#>   ! 5 columns were requested but there were 3 predictors in the data.
#>   ℹ 3 predictors will be used.

mars_spec <- mars(mode = "regression", engine = "earth") |> 
  set_args(
    num_terms = tune(), prod_degree = tune(), prune_method =  tune()
  ) 
mars_res <- workflow(rec_spec, mars_spec) |> 
  tune_grid(penguins_fold, control = control_stack_grid(), grid = GRID_SIZE)
#> x Fold1: preprocessor 1/1, model 4/8: Error: the nfold argument must be specified when pmethod="cv"
#> x Fold2: preprocessor 1/1, model 4/8: Error: the nfold argument must be specified when pmethod="cv"
#> x Fold3: preprocessor 1/1, model 4/8: Error: the nfold argument must be specified when pmethod="cv"

cubist_spec <- cubist_rules(mode = "regression", engine = "Cubist") |> 
  set_args(committees = tune(), neighbors = tune())
cubist_res <- workflow(rec_spec, cubist_spec) |> 
  tune_grid(penguins_fold, control = control_stack_grid(), grid = GRID_SIZE)

results <- as_workflow_set(
  tree_res = tree_res,
  glm_res = glm_res,
  xgb_res = xgb_res,
  ranger_res = ranger_res,
  mars_res = mars_res,
  cubist_res = cubist_res
)

results_ranked <- rank_results(results, rank_metric = "rmse")
results_ranked
#> # A tibble: 108 × 9
#>    wflow_id   .config     .metric    mean std_err     n preprocessor model  rank
#>    <chr>      <chr>       <chr>     <dbl>   <dbl> <int> <chr>        <chr> <int>
#>  1 xgb_res    Preprocess… rmse    455.    106.        3 recipe       boos…     1
#>  2 xgb_res    Preprocess… rsq       0.662   0.124     3 recipe       boos…     1
#>  3 xgb_res    Preprocess… rmse    458.    107.        3 recipe       boos…     2
#>  4 xgb_res    Preprocess… rsq       0.656   0.129     3 recipe       boos…     2
#>  5 ranger_res Preprocess… rmse    460.    104.        3 recipe       rand…     3
#>  6 ranger_res Preprocess… rsq       0.652   0.124     3 recipe       rand…     3
#>  7 ranger_res Preprocess… rmse    462.    105.        3 recipe       rand…     4
#>  8 ranger_res Preprocess… rsq       0.649   0.125     3 recipe       rand…     4
#>  9 ranger_res Preprocess… rmse    463.    104.        3 recipe       rand…     5
#> 10 ranger_res Preprocess… rsq       0.650   0.124     3 recipe       rand…     5
#> # ℹ 98 more rows
# results |> autoplot(metric = "rmse", id = "xgb_res")

best_results <- 
  results |> 
  extract_workflow_set_result(results_ranked[[1,"wflow_id"]])

results |> 
  extract_workflow(results_ranked[[1,"wflow_id"]]) |> 
  finalize_workflow(best_results  |> show_best(metric = "rmse", n = 1)) |>
  fit(data = penguins_train) |> 
  augment(penguins_test) |> 
  rmse(body_mass_g, .pred)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard        351.

# Saving, loading and using orbital with partykit
tree_fit <- workflow(rec_spec, tree_spec) |> 
  finalize_workflow(tree_res  |> show_best(metric = "rmse", n = 1)) |> 
  fit(data = penguins_train)
penguins_partykit <- bundle::bundle(tree_fit)
penguins_partykit |> write_rds("penguins_partykit.rds")
penguins_partykit <- read_rds("penguins_partykit.rds")
penguins_fit <- bundle::unbundle(penguins_partykit)
orbital_obj <- orbital(penguins_fit)
orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • species = dplyr::if_else(is.na(species), "unknown", species)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.3, bill_lengt ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.6, bill_depth_m ...
#> • species_Chinstrap = as.numeric(species == "Chinstrap")
#> • sex_male = as.numeric(sex == "male")
#> • bill_length_mm = bill_length_mm / 5.537645
#> • bill_depth_mm = bill_depth_mm / 1.969433
#> • species_Chinstrap = species_Chinstrap / 0.3972177
#> • sex_male = sex_male / 0.5008418
#> • bill_length_mm = bill_length_mm - 7.892217
#> • bill_depth_mm = bill_depth_mm - 8.783858
#> • species_Chinstrap = species_Chinstrap - 0.4917015
#> • sex_male = sex_male - 1.021717
#> • .pred = case_when(bill_depth_mm <= -0.2027075 & bill_length_mm <= -0.50 ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 15 equations in total.

# Saving, loading and using orbital with ranger
ranger_fit <- workflow(rec_spec, ranger_spec) |> 
  finalize_workflow(ranger_res  |> show_best(metric = "rmse", n = 1)) |> 
  fit(data = penguins_train)
penguins_ranger <- bundle::bundle(ranger_fit)
penguins_ranger |> write_rds("penguins_ranger.rds")
penguins_ranger <- read_rds("penguins_ranger.rds")
penguins_fit <- bundle::unbundle(penguins_ranger)
orbital_obj <- orbital(penguins_fit)
orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • species = dplyr::if_else(is.na(species), "unknown", species)
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.3, bill_lengt ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.6, bill_depth_m ...
#> • species_Chinstrap = as.numeric(species == "Chinstrap")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • bill_length_mm = bill_length_mm / 5.537645
#> • bill_depth_mm = bill_depth_mm / 1.969433
#> • species_Chinstrap = species_Chinstrap / 0.3972177
#> • island_Torgersen = island_Torgersen / 0.3523164
#> • sex_male = sex_male / 0.5008418
#> • bill_length_mm = bill_length_mm - 7.892217
#> • bill_depth_mm = bill_depth_mm - 8.783858
#> • species_Chinstrap = species_Chinstrap - 0.4917015
#> • island_Torgersen = island_Torgersen - 0.4102314
#> • sex_male = sex_male - 1.021717
#> • .pred = list(case_when(bill_depth_mm < -0.5835277 & sex_male < -0.02339 ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 19 equations in total.

# Saving, loading and using orbital with xgb (without bundle)
xgb_fit <- workflow(rec_spec, xgb_spec) |> 
  finalize_workflow(xgb_res  |> show_best(metric = "rmse", n = 1)) |> 
  fit(data = penguins_train)
xgb_fit |> write_rds("penguins_xgb.rds")
penguins_fit <- read_rds("penguins_xgb.rds")
orbital_obj <- orbital(penguins_fit)
orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • species = dplyr::if_else(is.na(species), "unknown", species)
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.3, bill_lengt ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.6, bill_depth_m ...
#> • species_Chinstrap = as.numeric(species == "Chinstrap")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • bill_length_mm = bill_length_mm / 5.537645
#> • bill_depth_mm = bill_depth_mm / 1.969433
#> • species_Chinstrap = species_Chinstrap / 0.3972177
#> • island_Torgersen = island_Torgersen / 0.3523164
#> • sex_male = sex_male / 0.5008418
#> • bill_length_mm = bill_length_mm - 7.892217
#> • bill_depth_mm = bill_depth_mm - 8.783858
#> • species_Chinstrap = species_Chinstrap - 0.4917015
#> • island_Torgersen = island_Torgersen - 0.4102314
#> • sex_male = sex_male - 1.021717
#> • .pred = 0 + case_when((bill_depth_mm < -0.4565876 | is.na(bill_depth_mm ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 19 equations in total.

# Saving, loading and using orbital with xgb (with bundle)
tryCatch({
  xgb_fit <- workflow(rec_spec, xgb_spec) |> 
    finalize_workflow(xgb_res  |> show_best(metric = "rmse", n = 1)) |> 
    fit(data = penguins_train)
  penguins_xgb <- bundle::bundle(xgb_fit)
  penguins_xgb |> write_rds("penguins_xgb.rds")
  penguins_xgb <- read_rds("penguins_xgb.rds")
  penguins_fit <- bundle::unbundle(penguins_xgb)
  orbital_obj <- orbital(penguins_fit)
  orbital_obj
  },
  error = \(e) print(e)
)
#> <simpleError in data.frame(Feature = as.character(0:(length(feature_names) -     1)), feature_name = feature_names, stringsAsFactors = FALSE): arguments imply differing number of rows: 2, 0>

tictoc::toc()
#> 4807.046 sec elapsed
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 (2024-06-14)
#>  os       Ubuntu 22.04.5 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Madrid
#>  date     2024-11-04
#>  pandoc   2.9.2.1 @ /bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version     date (UTC) lib source
#>  backports      1.5.0       2024-05-23 [1] CRAN (R 4.4.0)
#>  bonsai       * 0.3.1.9000  2024-11-01 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  broom        * 1.0.7       2024-09-26 [1] RSPM (R 4.4.0)
#>  bundle         0.1.1       2023-09-09 [1] CRAN (R 4.4.0)
#>  butcher        0.3.4       2024-04-11 [1] CRAN (R 4.4.0)
#>  class          7.3-22      2023-05-03 [4] CRAN (R 4.3.1)
#>  cli            3.6.3.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  codetools      0.2-19      2023-02-01 [4] CRAN (R 4.2.2)
#>  Cubist         0.4.2.1     2023-03-09 [1] CRAN (R 4.4.0)
#>  data.table     1.16.99     2024-10-26 [1] https://rdatatable.r-universe.dev (R 4.4.1)
#>  dials        * 1.3.0.9000  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  DiceDesign     1.10        2023-12-07 [1] CRAN (R 4.4.0)
#>  digest         0.6.37.1    2024-10-26 [1] https://eddelbuettel.r-universe.dev (R 4.4.1)
#>  doFuture       1.0.1       2023-12-20 [1] CRAN (R 4.4.0)
#>  dplyr        * 1.1.4.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  earth          5.3.4       2024-10-05 [1] CRAN (R 4.4.1)
#>  evaluate       1.0.0       2024-09-17 [1] RSPM (R 4.4.0)
#>  farver         2.1.2.9000  2024-10-26 [1] https://thomasp85.r-universe.dev (R 4.4.1)
#>  fastmap        1.2.0       2024-05-15 [1] RSPM (R 4.4.0)
#>  forcats      * 1.0.0       2023-01-29 [1] CRAN (R 4.4.0)
#>  foreach        1.5.2       2022-02-02 [1] CRAN (R 4.4.0)
#>  Formula        1.2-5       2023-02-24 [1] CRAN (R 4.4.0)
#>  fs             1.6.4       2024-04-25 [1] CRAN (R 4.4.0)
#>  furrr          0.3.1       2022-08-15 [1] CRAN (R 4.4.0)
#>  future       * 1.34.0      2024-07-29 [1] CRAN (R 4.4.1)
#>  future.apply   1.11.2      2024-03-28 [1] CRAN (R 4.4.0)
#>  generics       0.1.3.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  ggplot2      * 3.5.1.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  globals        0.16.3      2024-03-08 [1] CRAN (R 4.4.0)
#>  glue           1.8.0.9000  2024-11-01 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  gower          1.0.1       2022-12-22 [1] CRAN (R 4.4.0)
#>  GPfit          1.0-8       2019-02-08 [1] CRAN (R 4.4.0)
#>  gtable         0.3.6.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  hardhat        1.4.0.9002  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  hms            1.1.3       2023-03-21 [1] CRAN (R 4.4.0)
#>  htmltools      0.5.8.1     2024-04-04 [1] RSPM (R 4.4.0)
#>  htmlwidgets    1.6.4       2023-12-06 [1] RSPM (R 4.4.0)
#>  httr           1.4.7       2023-08-15 [1] CRAN (R 4.4.0)
#>  infer        * 1.0.7       2024-03-25 [1] CRAN (R 4.4.0)
#>  inum           1.0-5       2023-03-09 [1] CRAN (R 4.4.0)
#>  ipred          0.9-15      2024-07-18 [1] RSPM (R 4.4.0)
#>  iterators      1.0.14      2022-02-05 [1] CRAN (R 4.4.0)
#>  jsonlite       1.8.9       2024-09-20 [1] CRAN (R 4.4.1)
#>  knitr          1.48        2024-07-07 [1] RSPM
#>  lattice        0.22-6      2024-03-20 [1] CRAN (R 4.4.1)
#>  lava           1.8.0       2024-03-05 [1] CRAN (R 4.4.0)
#>  lazyeval       0.2.2       2019-03-15 [1] RSPM (R 4.4.0)
#>  lhs            1.2.0       2024-06-30 [1] RSPM (R 4.4.0)
#>  libcoin        1.0-10      2023-09-27 [1] CRAN (R 4.4.0)
#>  lifecycle      1.0.4       2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv        0.9.1       2024-01-29 [1] CRAN (R 4.4.0)
#>  lubridate    * 1.9.3.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  magrittr       2.0.3.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  MASS           7.3-60.2    2024-04-26 [1] RSPM
#>  Matrix         1.7-0       2024-04-26 [1] CRAN (R 4.4.1)
#>  modeldata    * 1.4.0       2024-06-19 [1] RSPM (R 4.4.0)
#>  mvtnorm        1.2-5       2024-05-21 [1] CRAN (R 4.4.0)
#>  nnet           7.3-19      2023-05-03 [4] CRAN (R 4.3.1)
#>  orbital      * 0.2.0.9000  2024-11-04 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  parallelly     1.38.0      2024-07-27 [1] RSPM (R 4.4.0)
#>  parsnip      * 1.2.1.9003  2024-11-04 [1] https://t~
#>  partykit       1.2-20      2023-04-14 [1] CRAN (R 4.4.0)
#>  pillar         1.9.0.9024  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  pkgconfig      2.0.3       2019-09-22 [1] CRAN (R 4.4.0)
#>  plotly       * 4.10.4.9000 2024-11-04 [1] https://plotly.r-universe.dev (R 4.4.1)
#>  plotmo         3.6.3       2024-02-26 [1] CRAN (R 4.4.0)
#>  plotrix        3.8-4       2023-11-10 [1] CRAN (R 4.4.0)
#>  plyr           1.8.9       2023-10-02 [1] CRAN (R 4.4.0)
#>  prodlim        2024.06.25  2024-06-24 [1] RSPM (R 4.4.0)
#>  purrr        * 1.0.2.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  R6             2.5.1.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  ranger         0.16.0      2023-11-12 [1] CRAN (R 4.4.0)
#>  RColorBrewer   1.1-3       2022-04-03 [1] CRAN (R 4.4.0)
#>  Rcpp           1.0.13.4    2024-10-26 [1] https://rcppcore.r-universe.dev (R 4.4.1)
#>  readr        * 2.1.5       2024-01-10 [1] CRAN (R 4.4.0)
#>  recipes      * 1.1.0.9001  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  reprex         2.1.1       2024-07-06 [1] RSPM (R 4.4.0)
#>  reshape2       1.4.4       2020-04-09 [1] CRAN (R 4.4.0)
#>  rlang          1.1.4.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  rmarkdown      2.28        2024-08-17 [1] RSPM
#>  rpart          4.1.23      2023-12-05 [4] CRAN (R 4.3.2)
#>  rsample      * 1.2.1.9000  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  rules        * 1.0.2.9000  2024-11-01 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  scales       * 1.3.0.9000  2024-11-01 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  sessioninfo    1.2.2       2021-12-06 [1] RSPM
#>  sfd            0.1.0.9000  2024-10-26 [1] https://topepo.r-universe.dev (R 4.4.1)
#>  sparsevctrs    0.1.0.9002  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  stacks       * 1.0.5.9000  2024-11-01 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  stringi        1.8.4.9001  2024-10-26 [1] https://gagolews.r-universe.dev (R 4.4.1)
#>  stringr      * 1.5.1.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  survival       3.7-0       2024-06-05 [4] CRAN (R 4.4.0)
#>  tibble       * 3.2.1.9032  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  tictoc         1.2.1       2024-03-18 [1] CRAN (R 4.4.0)
#>  tidymodels   * 1.2.0.9000  2024-11-01 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#>  tidypredict  * 0.5         2023-01-18 [1] CRAN (R 4.4.1)
#>  tidyr        * 1.3.1.9000  2024-10-26 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  tidyselect     1.2.1.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  tidyverse    * 2.0.0.9000  2024-11-01 [1] https://tidyverse.r-universe.dev (R 4.4.1)
#>  timechange     0.3.0       2024-01-18 [1] CRAN (R 4.4.0)
#>  timeDate       4041.110    2024-09-22 [1] RSPM (R 4.4.0)
#>  tune         * 1.2.1.9000  2024-11-04 [1] https://t~
#>  tzdb           0.4.0       2023-05-12 [1] CRAN (R 4.4.0)
#>  utf8           1.2.4       2023-10-22 [1] CRAN (R 4.4.0)
#>  vctrs          0.6.5.9000  2024-10-26 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  viridisLite    0.4.2       2023-05-02 [1] CRAN (R 4.4.0)
#>  withr          3.0.2.9000  2024-10-31 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  workflows    * 1.1.4.9000  2024-11-04 [1] https://t~
#>  workflowsets * 1.1.0       2024-03-21 [1] CRAN (R 4.4.0)
#>  xfun           0.48        2024-10-03 [1] RSPM (R 4.4.0)
#>  xgboost        1.7.8.1     2024-07-24 [1] RSPM (R 4.4.0)
#>  yaml           2.3.10      2024-07-26 [1] RSPM (R 4.4.0)
#>  yardstick    * 1.3.1.9000  2024-10-26 [1] https://tidymodels.r-universe.dev (R 4.4.1)
#> 
#>  [1] /home/jordi/R/x86_64-pc-linux-gnu-library/4.4
#>  [2] /usr/local/lib/R/site-library
#>  [3] /usr/lib/R/site-library
#>  [4] /usr/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
JavOrraca commented 1 week ago

Not sure if this is related, but I'm having trouble processing an {orbital} object using a trained XGBoost workflow() that was trimmed down with {butcher}. Thanks Emil & team.

EmilHvitfeldt commented 1 week ago

i see you! I'll see if I can take a look at this thursday or friday

EmilHvitfeldt commented 1 week ago

this is indeed a bug. We will see what we can do!

below is a smaller reprex

library(tidymodels)
library(bundle)
library(orbital)

mod <- boost_tree(trees = 5, mtry = 3) %>%
  set_mode("regression") %>%
  set_engine("xgboost") %>%
  fit(mpg ~ ., data = mtcars[1:25,])

orbital(mod)
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • .pred = 0 + case_when((cyl < 7 | is.na(cyl)) ~ 6.432858, cyl >= 7 ~ 4.0 ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 1 equations in total.

mod_bundle <- bundle(mod)
mod_new <- unbundle(mod_bundle)

orbital(mod_new)
#> Error in data.frame(Feature = as.character(0:(length(feature_names) - : arguments imply differing number of rows: 2, 0

# Because

mod$fit$nfeatures
#> [1] 10
mod$fit$feature_names
#>  [1] "cyl"  "disp" "hp"   "drat" "wt"   "qsec" "vs"   "am"   "gear" "carb"

mod_new$fit$nfeatures
#> NULL
mod_new$fit$feature_names
#> NULL
EmilHvitfeldt commented 1 week ago

@JavOrraca

butchering does not appear to affect xgboost, so there is another issue at hand

library(tidymodels)
library(butcher)
library(orbital)

mod <- boost_tree(trees = 5, mtry = 3) %>%
  set_mode("regression") %>%
  set_engine("xgboost") %>%
  fit(mpg ~ ., data = mtcars[1:25,])

orbital(mod)
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • .pred = 0 + case_when((disp < 266.9 | is.na(disp)) ~ 6.432858, disp >= ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 1 equations in total.

mod_butcher <- butcher(mod)

orbital(mod_butcher)
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • .pred = 0 + case_when((disp < 266.9 | is.na(disp)) ~ 6.432858, disp >= ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 1 equations in total.

mod$fit$nfeatures
#> [1] 10
mod$fit$feature_names
#>  [1] "cyl"  "disp" "hp"   "drat" "wt"   "qsec" "vs"   "am"   "gear" "carb"

mod_butcher$fit$nfeatures
#> [1] 10
mod_butcher$fit$feature_names
#>  [1] "cyl"  "disp" "hp"   "drat" "wt"   "qsec" "vs"   "am"   "gear" "carb"

Created on 2024-11-07 with reprex v2.1.0

jrosell commented 1 week ago

Thanks!

jrosell commented 2 days ago

Just to confirm that this is fixed with the new version of {bundle}:

library(tidymodels)
library(bundle)
library(orbital)

mod <- boost_tree(trees = 5, mtry = 3) %>%
  set_mode("regression") %>%
  set_engine("xgboost") %>%
  fit(mpg ~ ., data = mtcars[1:25,])

orbital(mod)
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • .pred = 0 + case_when((hp < 136.5 | is.na(hp)) ~ 6.432858, hp >= 136.5 ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 1 equations in total.

mod_bundle <- bundle(mod)
mod_new <- unbundle(mod_bundle)

orbital(mod_new)
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • .pred = 0 + case_when((hp < 136.5 | is.na(hp)) ~ 6.432858, hp >= 136.5 ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 1 equations in total.