tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org
Other
103 stars 17 forks source link

Add composition to preprocessor blueprints #150

Closed juliasilge closed 4 years ago

juliasilge commented 4 years ago

This is one of the first pieces needed for supporting sparse data structures in tidymodels.

library(hardhat)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step

train <- iris[1:100,]
test <- iris[101:150,]

rec <- recipe(Species ~ Sepal.Length + Sepal.Width, train) %>%
  step_log(Sepal.Length)

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
sparse_bp
#> Recipe blueprint: 
#>  
#> # Predictors: 0 
#>   # Outcomes: 0 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

processed <- mold(rec, train, blueprint = sparse_bp)
processed$blueprint
#> Recipe blueprint: 
#>  
#> # Predictors: 2 
#>   # Outcomes: 1 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

forge(train, blueprint = processed$blueprint)
#> $predictors
#> 83 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]     1.629241         3.5
#>  [2,]     1.589235         3.0
#>  [3,]     1.547563         3.2
#>  [4,]     1.526056         3.1
#>  [5,]     1.609438         3.6
#>  [6,]     1.686399         3.9
#>  [7,]     1.526056         3.4
#>  [8,]     1.609438         3.4
#>  [9,]     1.481605         2.9
#> [10,]     1.589235         3.1
#> [11,]     1.686399         3.7
#> [12,]     1.568616         3.4
#> [13,]     1.568616         3.0
#> [14,]     1.458615         3.0
#> [15,]     1.757858         4.0
#> [16,]     1.740466         4.4
#> [17,]     1.740466         3.8
#> [18,]     1.629241         3.8
#> [19,]     1.686399         3.4
#> [20,]     1.629241         3.7
#> [21,]     1.526056         3.6
#> [22,]     1.629241         3.3
#> [23,]     1.609438         3.0
#> [24,]     1.648659         3.5
#> [25,]     1.648659         3.4
#> [26,]     1.568616         3.1
#> [27,]     1.648659         4.1
#> [28,]     1.704748         4.2
#> [29,]     1.609438         3.2
#> [30,]     1.704748         3.5
#> [31,]     1.589235         3.6
#> [32,]     1.481605         3.0
#> [33,]     1.629241         3.4
#> [34,]     1.609438         3.5
#> [35,]     1.504077         2.3
#> [36,]     1.481605         3.2
#> [37,]     1.526056         3.2
#> [38,]     1.667707         3.7
#> [39,]     1.609438         3.3
#> [40,]     1.945910         3.2
#> [41,]     1.856298         3.2
#> [42,]     1.931521         3.1
#> [43,]     1.704748         2.3
#> [44,]     1.871802         2.8
#> [45,]     1.740466         2.8
#> [46,]     1.840550         3.3
#> [47,]     1.589235         2.4
#> [48,]     1.887070         2.9
#> [49,]     1.648659         2.7
#> [50,]     1.609438         2.0
#> [51,]     1.774952         3.0
#> [52,]     1.791759         2.2
#> [53,]     1.808289         2.9
#> [54,]     1.722767         2.9
#> [55,]     1.902108         3.1
#> [56,]     1.722767         3.0
#> [57,]     1.757858         2.7
#> [58,]     1.824549         2.2
#> [59,]     1.722767         2.5
#> [60,]     1.774952         3.2
#> [61,]     1.808289         2.8
#> [62,]     1.840550         2.5
#> [63,]     1.856298         2.9
#> [64,]     1.887070         3.0
#> [65,]     1.916923         2.8
#> [66,]     1.902108         3.0
#> [67,]     1.791759         2.9
#> [68,]     1.740466         2.6
#> [69,]     1.704748         2.4
#> [70,]     1.791759         2.7
#> [71,]     1.686399         3.0
#> [72,]     1.791759         3.4
#> [73,]     1.840550         2.3
#> [74,]     1.704748         2.5
#> [75,]     1.704748         2.6
#> [76,]     1.808289         3.0
#> [77,]     1.757858         2.6
#> [78,]     1.609438         2.3
#> [79,]     1.722767         2.7
#> [80,]     1.740466         3.0
#> [81,]     1.740466         2.9
#> [82,]     1.824549         2.9
#> [83,]     1.629241         2.5
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL
forge(test, blueprint = processed$blueprint)
#> $predictors
#> 44 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]     1.840550         3.3
#>  [2,]     1.757858         2.7
#>  [3,]     1.960095         3.0
#>  [4,]     1.840550         2.9
#>  [5,]     1.871802         3.0
#>  [6,]     2.028148         3.0
#>  [7,]     1.589235         2.5
#>  [8,]     1.987874         2.9
#>  [9,]     1.902108         2.5
#> [10,]     1.974081         3.6
#> [11,]     1.871802         3.2
#> [12,]     1.856298         2.7
#> [13,]     1.916923         3.0
#> [14,]     1.740466         2.5
#> [15,]     1.757858         2.8
#> [16,]     1.856298         3.2
#> [17,]     2.041220         3.8
#> [18,]     2.041220         2.6
#> [19,]     1.791759         2.2
#> [20,]     1.931521         3.2
#> [21,]     1.722767         2.8
#> [22,]     2.041220         2.8
#> [23,]     1.840550         2.7
#> [24,]     1.902108         3.3
#> [25,]     1.974081         3.2
#> [26,]     1.824549         2.8
#> [27,]     1.808289         3.0
#> [28,]     1.856298         2.8
#> [29,]     1.974081         3.0
#> [30,]     2.001480         2.8
#> [31,]     2.066863         3.8
#> [32,]     1.840550         2.8
#> [33,]     1.808289         2.6
#> [34,]     2.041220         3.0
#> [35,]     1.840550         3.4
#> [36,]     1.856298         3.1
#> [37,]     1.791759         3.0
#> [38,]     1.931521         3.1
#> [39,]     1.902108         3.1
#> [40,]     1.916923         3.2
#> [41,]     1.902108         3.0
#> [42,]     1.840550         2.5
#> [43,]     1.824549         3.4
#> [44,]     1.774952         3.0
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL

Created on 2020-09-28 by the reprex package (v0.3.0.9001)

juliasilge commented 4 years ago

Related to tidymodels/tidymodels#42

juliasilge commented 4 years ago

Closes #100 eventually

juliasilge commented 4 years ago

I worked more on this today, and now during forge() the recipe is only baked one time on the predictors but everything else wires up correctly. 🎉

library(hardhat)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step

train <- iris[1:100,]
test <- iris[101:150,]

rec <- recipe(Species ~ Sepal.Length + Sepal.Width, train) %>%
  step_normalize(Sepal.Length)

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
sparse_bp
#> Recipe blueprint: 
#>  
#> # Predictors: 0 
#>   # Outcomes: 0 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

processed <- mold(rec, train, blueprint = sparse_bp)
processed$blueprint
#> Recipe blueprint: 
#>  
#> # Predictors: 2 
#>   # Outcomes: 1 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

forge(train, blueprint = processed$blueprint)
#> $predictors
#> 100 x 2 sparse Matrix of class "dgCMatrix"
#>        Sepal.Length Sepal.Width
#>   [1,]  -0.57815327         3.5
#>   [2,]  -0.88982620         3.0
#>   [3,]  -1.20149912         3.2
#>   [4,]  -1.35733558         3.1
#>   [5,]  -0.73398974         3.6
#>   [6,]  -0.11064389         3.9
#>   [7,]  -1.35733558         3.4
#>   [8,]  -0.73398974         3.4
#>   [9,]  -1.66900851         2.9
#>  [10,]  -0.88982620         3.1
#>  [11,]  -0.11064389         3.7
#>  [12,]  -1.04566266         3.4
#>  [13,]  -1.04566266         3.0
#>  [14,]  -1.82484497         3.0
#>  [15,]   0.51270196         4.0
#>  [16,]   0.35686550         4.4
#>  [17,]  -0.11064389         3.9
#>  [18,]  -0.57815327         3.5
#>  [19,]   0.35686550         3.8
#>  [20,]  -0.57815327         3.8
#>  [21,]  -0.11064389         3.4
#>  [22,]  -0.57815327         3.7
#>  [23,]  -1.35733558         3.6
#>  [24,]  -0.57815327         3.3
#>  [25,]  -1.04566266         3.4
#>  [26,]  -0.73398974         3.0
#>  [27,]  -0.73398974         3.4
#>  [28,]  -0.42231681         3.5
#>  [29,]  -0.42231681         3.4
#>  [30,]  -1.20149912         3.2
#>  [31,]  -1.04566266         3.1
#>  [32,]  -0.11064389         3.4
#>  [33,]  -0.42231681         4.1
#>  [34,]   0.04519257         4.2
#>  [35,]  -0.88982620         3.1
#>  [36,]  -0.73398974         3.2
#>  [37,]   0.04519257         3.5
#>  [38,]  -0.88982620         3.6
#>  [39,]  -1.66900851         3.0
#>  [40,]  -0.57815327         3.4
#>  [41,]  -0.73398974         3.5
#>  [42,]  -1.51317205         2.3
#>  [43,]  -1.66900851         3.2
#>  [44,]  -0.73398974         3.5
#>  [45,]  -0.57815327         3.8
#>  [46,]  -1.04566266         3.0
#>  [47,]  -0.57815327         3.8
#>  [48,]  -1.35733558         3.2
#>  [49,]  -0.26648035         3.7
#>  [50,]  -0.73398974         3.3
#>  [51,]   2.38273950         3.2
#>  [52,]   1.44772073         3.2
#>  [53,]   2.22690304         3.1
#>  [54,]   0.04519257         2.3
#>  [55,]   1.60355719         2.8
#>  [56,]   0.35686550         2.8
#>  [57,]   1.29188427         3.3
#>  [58,]  -0.88982620         2.4
#>  [59,]   1.75939366         2.9
#>  [60,]  -0.42231681         2.7
#>  [61,]  -0.73398974         2.0
#>  [62,]   0.66853842         3.0
#>  [63,]   0.82437488         2.2
#>  [64,]   0.98021135         2.9
#>  [65,]   0.20102904         2.9
#>  [66,]   1.91523012         3.1
#>  [67,]   0.20102904         3.0
#>  [68,]   0.51270196         2.7
#>  [69,]   1.13604781         2.2
#>  [70,]   0.20102904         2.5
#>  [71,]   0.66853842         3.2
#>  [72,]   0.98021135         2.8
#>  [73,]   1.29188427         2.5
#>  [74,]   0.98021135         2.8
#>  [75,]   1.44772073         2.9
#>  [76,]   1.75939366         3.0
#>  [77,]   2.07106658         2.8
#>  [78,]   1.91523012         3.0
#>  [79,]   0.82437488         2.9
#>  [80,]   0.35686550         2.6
#>  [81,]   0.04519257         2.4
#>  [82,]   0.04519257         2.4
#>  [83,]   0.51270196         2.7
#>  [84,]   0.82437488         2.7
#>  [85,]  -0.11064389         3.0
#>  [86,]   0.82437488         3.4
#>  [87,]   1.91523012         3.1
#>  [88,]   1.29188427         2.3
#>  [89,]   0.20102904         3.0
#>  [90,]   0.04519257         2.5
#>  [91,]   0.04519257         2.6
#>  [92,]   0.98021135         3.0
#>  [93,]   0.51270196         2.6
#>  [94,]  -0.73398974         2.3
#>  [95,]   0.20102904         2.7
#>  [96,]   0.35686550         3.0
#>  [97,]   0.35686550         2.9
#>  [98,]   1.13604781         2.9
#>  [99,]  -0.57815327         2.5
#> [100,]   0.35686550         2.8
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL
forge(test, blueprint = processed$blueprint, outcomes = TRUE)
#> $predictors
#> 50 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]    1.2918843         3.3
#>  [2,]    0.5127020         2.7
#>  [3,]    2.5385760         3.0
#>  [4,]    1.2918843         2.9
#>  [5,]    1.6035572         3.0
#>  [6,]    3.3177583         3.0
#>  [7,]   -0.8898262         2.5
#>  [8,]    2.8502489         2.9
#>  [9,]    1.9152301         2.5
#> [10,]    2.6944124         3.6
#> [11,]    1.6035572         3.2
#> [12,]    1.4477207         2.7
#> [13,]    2.0710666         3.0
#> [14,]    0.3568655         2.5
#> [15,]    0.5127020         2.8
#> [16,]    1.4477207         3.2
#> [17,]    1.6035572         3.0
#> [18,]    3.4735947         3.8
#> [19,]    3.4735947         2.6
#> [20,]    0.8243749         2.2
#> [21,]    2.2269030         3.2
#> [22,]    0.2010290         2.8
#> [23,]    3.4735947         2.8
#> [24,]    1.2918843         2.7
#> [25,]    1.9152301         3.3
#> [26,]    2.6944124         3.2
#> [27,]    1.1360478         2.8
#> [28,]    0.9802113         3.0
#> [29,]    1.4477207         2.8
#> [30,]    2.6944124         3.0
#> [31,]    3.0060854         2.8
#> [32,]    3.7852677         3.8
#> [33,]    1.4477207         2.8
#> [34,]    1.2918843         2.8
#> [35,]    0.9802113         2.6
#> [36,]    3.4735947         3.0
#> [37,]    1.2918843         3.4
#> [38,]    1.4477207         3.1
#> [39,]    0.8243749         3.0
#> [40,]    2.2269030         3.1
#> [41,]    1.9152301         3.1
#> [42,]    2.2269030         3.1
#> [43,]    0.5127020         2.7
#> [44,]    2.0710666         3.2
#> [45,]    1.9152301         3.3
#> [46,]    1.9152301         3.0
#> [47,]    1.2918843         2.5
#> [48,]    1.6035572         3.0
#> [49,]    1.1360478         3.4
#> [50,]    0.6685384         3.0
#> 
#> $outcomes
#> # A tibble: 50 x 1
#>    Species  
#>    <fct>    
#>  1 virginica
#>  2 virginica
#>  3 virginica
#>  4 virginica
#>  5 virginica
#>  6 virginica
#>  7 virginica
#>  8 virginica
#>  9 virginica
#> 10 virginica
#> # … with 40 more rows
#> 
#> $extras
#> $extras$roles
#> NULL

Created on 2020-09-29 by the reprex package (v0.3.0.9001)

A few things to note:

I set this:

#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format
#' of the processed predictors.

I don't think there's much point to handling an explicit data.frame case, although it is a composition option in recipes.

topepo commented 4 years ago

This might be more of a workflows issue, but .fit_pre() still generates a tibble:

# See https://github.com/tidymodels/hardhat/pull/150
# remotes::install_github("tidymodels/hardhat@recipe-blueprint-composition")

# See https://github.com/tidymodels/parsnip/pull/373
# remotes::install_github("tidymodels/parsnip@sparsity")

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0          ✓ recipes   0.1.13    
#> ✓ dials     0.0.9          ✓ rsample   0.0.8     
#> ✓ dplyr     1.0.2          ✓ tibble    3.0.3     
#> ✓ ggplot2   3.3.2          ✓ tidyr     1.1.2     
#> ✓ infer     0.5.2          ✓ tune      0.1.1     
#> ✓ modeldata 0.0.2          ✓ workflows 0.2.0     
#> ✓ parsnip   0.1.3.9000     ✓ yardstick 0.0.7     
#> ✓ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
library(hardhat)

data(ames)

ames <- 
  ames %>% 
  mutate(Sale_Price = log10(Sale_Price)) %>% 
  select(Sale_Price, Longitude, Latitude, Neighborhood)

rec <- 
  recipe(Sale_Price ~ ., data = ames) %>% 
  step_dummy(Neighborhood) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")

processed <- mold(rec, ames, blueprint = sparse_bp)
class(forge(ames, blueprint = processed$blueprint)$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"

lm_spec <- linear_reg(mixture = .5) %>% set_engine("glmnet")

lm_wflow <- 
  workflow() %>% 
  add_recipe(rec, blueprint = sparse_bp) %>% 
  add_model(lm_spec)

lm_wflow_1 <- .fit_pre(lm_wflow, data = ames)
lm_wflow_1$pre$mold$predictors
#> # A tibble: 2,930 x 29
#>    Longitude Latitude Neighborhood_Co… Neighborhood_Ol… Neighborhood_Ed…
#>        <dbl>    <dbl>            <dbl>            <dbl>            <dbl>
#>  1     0.901    1.06            -0.317           -0.298           -0.266
#>  2     0.900    1.01            -0.317           -0.298           -0.266
#>  3     0.915    0.987           -0.317           -0.298           -0.266
#>  4     0.995    0.911           -0.317           -0.298           -0.266
#>  5     0.154    1.43            -0.317           -0.298           -0.266
#>  6     0.155    1.43            -0.317           -0.298           -0.266
#>  7     0.354    1.55            -0.317           -0.298           -0.266
#>  8     0.353    1.43            -0.317           -0.298           -0.266
#>  9     0.391    1.45            -0.317           -0.298           -0.266
#> 10     0.149    1.34            -0.317           -0.298           -0.266
#> # … with 2,920 more rows, and 24 more variables: Neighborhood_Somerset <dbl>,
#> #   Neighborhood_Northridge_Heights <dbl>, Neighborhood_Gilbert <dbl>,
#> #   Neighborhood_Sawyer <dbl>, Neighborhood_Northwest_Ames <dbl>,
#> #   Neighborhood_Sawyer_West <dbl>, Neighborhood_Mitchell <dbl>,
#> #   Neighborhood_Brookside <dbl>, Neighborhood_Crawford <dbl>,
#> #   Neighborhood_Iowa_DOT_and_Rail_Road <dbl>, Neighborhood_Timberland <dbl>,
#> #   Neighborhood_Northridge <dbl>, Neighborhood_Stone_Brook <dbl>,
#> #   Neighborhood_South_and_West_of_Iowa_State_University <dbl>,
#> #   Neighborhood_Clear_Creek <dbl>, Neighborhood_Meadow_Village <dbl>,
#> #   Neighborhood_Briardale <dbl>, Neighborhood_Bloomington_Heights <dbl>,
#> #   Neighborhood_Veenker <dbl>, Neighborhood_Northpark_Villa <dbl>,
#> #   Neighborhood_Blueste <dbl>, Neighborhood_Greens <dbl>,
#> #   Neighborhood_Green_Hills <dbl>, Neighborhood_Landmark <dbl>

Created on 2020-09-30 by the reprex package (v0.3.0)

Session info

``` r
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS Catalina 10.15.5
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/New_York
#>  date     2020-09-30
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib source
#>  assertthat    0.2.1      2019-03-21 [1] CRAN (R 4.0.0)
#>  backports     1.1.10     2020-09-15 [1] CRAN (R 4.0.2)
#>  broom       * 0.7.0      2020-07-09 [1] CRAN (R 4.0.0)
#>  callr         3.4.4      2020-09-07 [1] CRAN (R 4.0.2)
#>  class         7.3-17     2020-04-26 [1] CRAN (R 4.0.2)
#>  cli           2.0.2      2020-02-28 [1] CRAN (R 4.0.0)
#>  codetools     0.2-16     2018-12-24 [1] CRAN (R 4.0.2)
#>  colorspace    1.4-1      2019-03-18 [1] CRAN (R 4.0.0)
#>  crayon        1.3.4.9000 2020-08-18 [1] Github (r-lib/crayon@6b3f0c6)
#>  desc          1.2.0      2018-05-01 [1] CRAN (R 4.0.0)
#>  devtools      2.3.1      2020-07-21 [1] CRAN (R 4.0.2)
#>  dials       * 0.0.9      2020-09-16 [1] CRAN (R 4.0.2)
#>  DiceDesign    1.8-1      2019-07-31 [1] CRAN (R 4.0.0)
#>  digest        0.6.25     2020-02-23 [1] CRAN (R 4.0.0)
#>  dplyr       * 1.0.2      2020-08-18 [1] CRAN (R 4.0.0)
#>  ellipsis      0.3.1      2020-05-15 [1] CRAN (R 4.0.0)
#>  evaluate      0.14       2019-05-28 [1] CRAN (R 4.0.0)
#>  fansi         0.4.1      2020-01-08 [1] CRAN (R 4.0.0)
#>  foreach       1.5.0      2020-03-30 [1] CRAN (R 4.0.2)
#>  fs            1.5.0      2020-07-31 [1] CRAN (R 4.0.2)
#>  furrr         0.1.0      2018-05-16 [1] CRAN (R 4.0.0)
#>  future        1.19.1     2020-09-22 [1] CRAN (R 4.0.2)
#>  generics      0.0.2      2018-11-29 [1] CRAN (R 4.0.0)
#>  ggplot2     * 3.3.2      2020-06-19 [1] CRAN (R 4.0.0)
#>  globals       0.13.0     2020-09-17 [1] CRAN (R 4.0.2)
#>  glue          1.4.2      2020-08-27 [1] CRAN (R 4.0.2)
#>  gower         0.2.2      2020-06-23 [1] CRAN (R 4.0.0)
#>  GPfit         1.0-8      2019-02-08 [1] CRAN (R 4.0.0)
#>  gtable        0.3.0      2019-03-25 [1] CRAN (R 4.0.0)
#>  hardhat     * 0.1.4.9000 2020-09-30 [1] Github (tidymodels/hardhat@b763b1f)
#>  highr         0.8        2019-03-20 [1] CRAN (R 4.0.0)
#>  htmltools     0.5.0      2020-06-16 [1] CRAN (R 4.0.0)
#>  infer       * 0.5.2      2020-06-14 [1] CRAN (R 4.0.0)
#>  ipred         0.9-9      2019-04-28 [1] CRAN (R 4.0.2)
#>  iterators     1.0.12     2019-07-26 [1] CRAN (R 4.0.0)
#>  knitr         1.30       2020-09-22 [1] CRAN (R 4.0.2)
#>  lattice       0.20-41    2020-04-02 [1] CRAN (R 4.0.2)
#>  lava          1.6.8      2020-09-26 [1] CRAN (R 4.0.2)
#>  lhs           1.1.0      2020-09-29 [1] CRAN (R 4.0.2)
#>  lifecycle     0.2.0      2020-03-06 [1] CRAN (R 4.0.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.0.0)
#>  lubridate     1.7.9      2020-06-08 [1] CRAN (R 4.0.2)
#>  magrittr      1.5        2014-11-22 [1] CRAN (R 4.0.0)
#>  MASS          7.3-51.6   2020-04-26 [1] CRAN (R 4.0.2)
#>  Matrix        1.2-18     2019-11-27 [1] CRAN (R 4.0.2)
#>  memoise       1.1.0      2017-04-21 [1] CRAN (R 4.0.0)
#>  modeldata   * 0.0.2      2020-06-22 [1] CRAN (R 4.0.2)
#>  munsell       0.5.0      2018-06-12 [1] CRAN (R 4.0.0)
#>  nnet          7.3-14     2020-04-26 [1] CRAN (R 4.0.2)
#>  parsnip     * 0.1.3.9000 2020-09-30 [1] Github (tidymodels/parsnip@659b5ad)
#>  pillar        1.4.6      2020-07-10 [1] CRAN (R 4.0.0)
#>  pkgbuild      1.1.0      2020-07-13 [1] CRAN (R 4.0.2)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.0.0)
#>  pkgload       1.1.0      2020-05-29 [1] CRAN (R 4.0.0)
#>  plyr          1.8.6      2020-03-03 [1] CRAN (R 4.0.2)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.0.0)
#>  pROC          1.16.2     2020-03-19 [1] CRAN (R 4.0.2)
#>  processx      3.4.4      2020-09-03 [1] CRAN (R 4.0.2)
#>  prodlim       2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
#>  ps            1.3.4      2020-08-11 [1] CRAN (R 4.0.2)
#>  purrr       * 0.3.4      2020-04-17 [1] CRAN (R 4.0.0)
#>  R6            2.4.1      2019-11-12 [1] CRAN (R 4.0.0)
#>  Rcpp          1.0.5      2020-07-06 [1] CRAN (R 4.0.0)
#>  recipes     * 0.1.13     2020-06-23 [1] CRAN (R 4.0.2)
#>  remotes       2.2.0      2020-07-21 [1] CRAN (R 4.0.2)
#>  rlang         0.4.7      2020-07-09 [1] CRAN (R 4.0.0)
#>  rmarkdown     2.3        2020-06-18 [1] CRAN (R 4.0.2)
#>  rpart         4.1-15     2019-04-12 [1] CRAN (R 4.0.2)
#>  rprojroot     1.3-2      2018-01-03 [1] CRAN (R 4.0.0)
#>  rsample     * 0.0.8      2020-09-23 [1] CRAN (R 4.0.2)
#>  rstudioapi    0.11       2020-02-07 [1] CRAN (R 4.0.0)
#>  scales      * 1.1.1      2020-05-11 [1] CRAN (R 4.0.2)
#>  sessioninfo   1.1.1      2018-11-05 [1] CRAN (R 4.0.2)
#>  stringi       1.5.3      2020-09-09 [1] CRAN (R 4.0.2)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.0.0)
#>  survival      3.1-12     2020-04-10 [1] CRAN (R 4.0.2)
#>  testthat      2.3.2      2020-03-02 [1] CRAN (R 4.0.2)
#>  tibble      * 3.0.3      2020-07-10 [1] CRAN (R 4.0.0)
#>  tidymodels  * 0.1.1      2020-07-14 [1] CRAN (R 4.0.0)
#>  tidyr       * 1.1.2      2020-08-27 [1] CRAN (R 4.0.2)
#>  tidyselect    1.1.0      2020-05-11 [1] CRAN (R 4.0.0)
#>  timeDate      3043.102   2018-02-21 [1] CRAN (R 4.0.0)
#>  tune        * 0.1.1      2020-07-08 [1] CRAN (R 4.0.2)
#>  usethis       1.9.0.9000 2020-09-30 [1] Github (r-lib/usethis@b993e83)
#>  utf8          1.1.4      2018-05-24 [1] CRAN (R 4.0.0)
#>  vctrs         0.3.4      2020-08-29 [1] CRAN (R 4.0.2)
#>  withr         2.3.0      2020-09-22 [1] CRAN (R 4.0.2)
#>  workflows   * 0.2.0      2020-09-15 [1] CRAN (R 4.0.2)
#>  xfun          0.17       2020-09-09 [1] CRAN (R 4.0.2)
#>  yaml          2.2.1      2020-02-01 [1] CRAN (R 4.0.0)
#>  yardstick   * 0.0.7      2020-07-13 [1] CRAN (R 4.0.2)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
```

juliasilge commented 4 years ago

OK, I think this is getting there! 👍 I now have mold() working as well, and it is wired up to the various important bits:

library(tidymodels)
library(hardhat)

data(ames, package = "modeldata")
x <- ames %>%
  select(Longitude, Latitude, Year_Built)
y <- log10(ames$Sale_Price)

## works for xy
bp1 <- default_xy_blueprint(composition = "dgCMatrix")
x1 <- mold(x, y, blueprint = bp1)
class(x1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(x1$predictors)
#> [1] "Longitude"  "Latitude"   "Year_Built"

## works for formula
bp2 <- default_formula_blueprint(composition = "matrix")
x2 <- mold(log10(Sale_Price) ~ Longitude + Latitude + Neighborhood, ames, blueprint = bp2)
class(x2$predictors)
#> [1] "matrix" "array"
colnames(x2$predictors)
#>  [1] "Longitude"                                          
#>  [2] "Latitude"                                           
#>  [3] "NeighborhoodNorth_Ames"                             
#>  [4] "NeighborhoodCollege_Creek"                          
#>  [5] "NeighborhoodOld_Town"                               
#>  [6] "NeighborhoodEdwards"                                
#>  [7] "NeighborhoodSomerset"                               
#>  [8] "NeighborhoodNorthridge_Heights"                     
#>  [9] "NeighborhoodGilbert"                                
#> [10] "NeighborhoodSawyer"                                 
#> [11] "NeighborhoodNorthwest_Ames"                         
#> [12] "NeighborhoodSawyer_West"                            
#> [13] "NeighborhoodMitchell"                               
#> [14] "NeighborhoodBrookside"                              
#> [15] "NeighborhoodCrawford"                               
#> [16] "NeighborhoodIowa_DOT_and_Rail_Road"                 
#> [17] "NeighborhoodTimberland"                             
#> [18] "NeighborhoodNorthridge"                             
#> [19] "NeighborhoodStone_Brook"                            
#> [20] "NeighborhoodSouth_and_West_of_Iowa_State_University"
#> [21] "NeighborhoodClear_Creek"                            
#> [22] "NeighborhoodMeadow_Village"                         
#> [23] "NeighborhoodBriardale"                              
#> [24] "NeighborhoodBloomington_Heights"                    
#> [25] "NeighborhoodVeenker"                                
#> [26] "NeighborhoodNorthpark_Villa"                        
#> [27] "NeighborhoodBlueste"                                
#> [28] "NeighborhoodGreens"                                 
#> [29] "NeighborhoodGreen_Hills"                            
#> [30] "NeighborhoodLandmark"                               
#> [31] "NeighborhoodHayden_Lake"

## can forge
xx1 <- forge(ames, blueprint = x1$blueprint)
class(xx1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"

## workflow can get to the "recomposed" data
rec <- 
  recipe(Sale_Price ~  Longitude + Latitude + Neighborhood, data = ames) %>% 
  step_dummy(Neighborhood) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lasso_spec <- linear_reg(mixture = 1) %>% set_engine("glmnet")
bp3 <- default_recipe_blueprint(composition = "dgCMatrix")

lasso_wflow <- 
  workflow() %>% 
  add_recipe(rec, blueprint = bp3) %>% 
  add_model(lasso_spec)

wflow_1 <- .fit_pre(lasso_wflow, data = ames)
class(wflow_1$pre$mold$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(wflow_1$pre$mold$predictors)
#>  [1] "Longitude"                                           
#>  [2] "Latitude"                                            
#>  [3] "Neighborhood_College_Creek"                          
#>  [4] "Neighborhood_Old_Town"                               
#>  [5] "Neighborhood_Edwards"                                
#>  [6] "Neighborhood_Somerset"                               
#>  [7] "Neighborhood_Northridge_Heights"                     
#>  [8] "Neighborhood_Gilbert"                                
#>  [9] "Neighborhood_Sawyer"                                 
#> [10] "Neighborhood_Northwest_Ames"                         
#> [11] "Neighborhood_Sawyer_West"                            
#> [12] "Neighborhood_Mitchell"                               
#> [13] "Neighborhood_Brookside"                              
#> [14] "Neighborhood_Crawford"                               
#> [15] "Neighborhood_Iowa_DOT_and_Rail_Road"                 
#> [16] "Neighborhood_Timberland"                             
#> [17] "Neighborhood_Northridge"                             
#> [18] "Neighborhood_Stone_Brook"                            
#> [19] "Neighborhood_South_and_West_of_Iowa_State_University"
#> [20] "Neighborhood_Clear_Creek"                            
#> [21] "Neighborhood_Meadow_Village"                         
#> [22] "Neighborhood_Briardale"                              
#> [23] "Neighborhood_Bloomington_Heights"                    
#> [24] "Neighborhood_Veenker"                                
#> [25] "Neighborhood_Northpark_Villa"                        
#> [26] "Neighborhood_Blueste"                                
#> [27] "Neighborhood_Greens"                                 
#> [28] "Neighborhood_Green_Hills"                            
#> [29] "Neighborhood_Landmark"

Created on 2020-09-30 by the reprex package (v0.3.0.9001)

FYI, I made significant changes to the tests for the preprocessors.

github-actions[bot] commented 3 years ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.