tidymodels / workflows

Modeling Workflows
https://workflows.tidymodels.org/

Can't augment() a workflow where recipe blueprint has sparse composition #148

Closed by EmilHvitfeldt 2 years ago

EmilHvitfeldt commented 2 years ago
library(tidymodels)
library(hardhat)

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")

rec_spec <-
  recipe(mpg ~ ., data = mtcars)

xgb_spec <-
  boost_tree(mode = "regression") %>%
  set_engine("xgboost")

wf_sparse <- 
  workflow() %>%
  add_recipe(rec_spec, blueprint = sparse_bp) %>%
  add_model(xgb_spec)

fit_sparse <- fit(wf_sparse, mtcars)

augment(fit_sparse, new_data = mtcars)
#> Error in `dplyr::bind_cols()`:
#> ! Input must be a vector, not a <dgCMatrix> object.

Created on 2022-03-25 by the reprex package (v2.0.1)

juliasilge commented 2 years ago

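You can augment() with composition = "matrix", but the result is not right; the predictor columns come back duplicated under repaired names (cyl...2, cyl...12, and so on):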

library(tidymodels)
library(hardhat)

sparse_bp <- default_recipe_blueprint(composition = "matrix")

rec_spec <-
  recipe(mpg ~ ., data = mtcars)

xgb_spec <-
  boost_tree(mode = "regression") %>%
  set_engine("xgboost")

wf_sparse <- 
  workflow() %>%
  add_recipe(rec_spec, blueprint = sparse_bp) %>%
  add_model(xgb_spec)

fit_sparse <- fit(wf_sparse, mtcars)

augment(fit_sparse, new_data = mtcars)
#> New names:
#> * cyl -> cyl...2
#> * disp -> disp...3
#> * hp -> hp...4
#> * drat -> drat...5
#> * wt -> wt...6
#> * ...
#> # A tibble: 32 × 22
#>      mpg cyl...2 disp...3 hp...4 drat...5 wt...6 qsec...7 vs...8 am...9
#>  * <dbl>   <dbl>    <dbl>  <dbl>    <dbl>  <dbl>    <dbl>  <dbl>  <dbl>
#>  1  21         6     160     110     3.9    2.62     16.5      0      1
#>  2  21         6     160     110     3.9    2.88     17.0      0      1
#>  3  22.8       4     108      93     3.85   2.32     18.6      1      1
#>  4  21.4       6     258     110     3.08   3.22     19.4      1      0
#>  5  18.7       8     360     175     3.15   3.44     17.0      0      0
#>  6  18.1       6     225     105     2.76   3.46     20.2      1      0
#>  7  14.3       8     360     245     3.21   3.57     15.8      0      0
#>  8  24.4       4     147.     62     3.69   3.19     20        1      0
#>  9  22.8       4     141.     95     3.92   3.15     22.9      1      0
#> 10  19.2       6     168.    123     3.92   3.44     18.3      1      0
#> # … with 22 more rows, and 13 more variables: gear...10 <dbl>, carb...11 <dbl>,
#> #   cyl...12 <dbl>, disp...13 <dbl>, hp...14 <dbl>, drat...15 <dbl>,
#> #   wt...16 <dbl>, qsec...17 <dbl>, vs...18 <dbl>, am...19 <dbl>,
#> #   gear...20 <dbl>, carb...21 <dbl>, .pred <dbl>

Created on 2022-03-25 by the reprex package (v2.0.1)
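
Until augment() handles non-tibble compositions, one possible workaround is to call predict() on the fitted workflow and bind the predictions to the original data by hand. This is only a sketch: it assumes predict() itself is unaffected by the sparse blueprint, which this thread does not confirm, and the name `augmented` is purely illustrative.

```r
library(tidymodels)
library(hardhat)

# Same setup as the reprex above: a recipe blueprint with sparse composition
sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")

wf_sparse <-
  workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars), blueprint = sparse_bp) %>%
  add_model(boost_tree(mode = "regression") %>% set_engine("xgboost"))

fit_sparse <- fit(wf_sparse, mtcars)

# Assumption: predict() works with the sparse blueprint even though augment() errors
preds <- predict(fit_sparse, new_data = mtcars)

# Bind predictions to the original (unprocessed) data rather than to the
# forged predictors, so no dgCMatrix reaches bind_cols() and no columns repeat
augmented <- dplyr::bind_cols(tibble::as_tibble(mtcars), preds)
```

Binding to the original data frame avoids both failures shown above: the sparse matrix never reaches bind_cols(), and the predictors appear only once.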
