tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

workflows predict.workflow and hardhat forge throw Error: Can't subset columns that don't exist. when recipe includes outcome variables. Cross-post #149

Closed yogat3ch closed 4 years ago

yogat3ch commented 4 years ago

Crosspost from workflows#60

Hi folks, the predict method for workflow objects (workflows:::predict.workflow) appears to drop the outcome variables in hardhat::forge, which causes an error if the recipe has a pre-processing step that involves the outcome variables. I'm cross-posting this to hardhat since I think the issue may be in the way variables are passed inside hardhat:::forge.data.frame.

Here's a reprex:

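library(magrittr)  # provides the %>% pipe

# Simulated data: one outcome (y) and nine predictors (x1..x9)
test <- setNames(as.data.frame(matrix(rnorm(1000), ncol = 10)), c("y", paste0("x", 1:9)))

# Recipe with a pre-processing step applied to the outcome
rec <- recipes::recipe(y ~ ., data = test) %>%
  recipes::step_normalize(recipes::all_outcomes())

# Workflow combining the recipe with an xgboost regression model
xgboost <- workflows::workflow() %>%
  workflows::add_recipe(rec) %>%
  workflows::add_model(
    parsnip::boost_tree(
      mode = "regression",
      mtry = 8,
      trees = 12,
      min_n = 9,
      tree_depth = 8,
      learn_rate = 0.059,
      loss_reduction = 0.0001
    ) %>%
      parsnip::set_engine("xgboost")
  )

fitted <- parsnip::fit(xgboost, test)

# Error: Can't subset columns that don't exist.
predict(fitted, new_data = test)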

Here's the stack trace.

  1. +-stats::predict(.out, new_data = .d$test)
  2. +-workflows:::predict.workflow(.out, new_data = .d$test)
  3. | +-hardhat::forge(new_data, blueprint)
  4. | \-hardhat:::forge.data.frame(new_data, blueprint)
  5. |   \-blueprint$forge$process(...)
  6. |     +-recipes::bake(object = rec, new_data = new_data)
  7. |     \-recipes:::bake.recipe(object = rec, new_data = new_data)
  8. |       +-recipes::bake(object$steps[[i]], new_data = new_data)
  9. |       \-recipes:::bake.step_normalize(object$steps[[i]], new_data = new_data)
 10. |         +-base::sweep(...)
 11. |         +-base::as.matrix(new_data[, names(object$means)])
 12. |         +-new_data[, names(object$means)]
 13. |         \-tibble:::`[.tbl_df`(new_data, , names(object$means))
 14. |           \-tibble:::vectbl_as_col_location(...)
 15. |             +-tibble:::subclass_col_index_errors(...)
 16. |             | \-base::withCallingHandlers(...)
 17. |             \-vctrs::vec_as_location(j, n, names)
 18. \-vctrs:::stop_subscript_oob(...)
 19.   \-vctrs:::stop_subscript(...)

hardhat::forge is not passed outcomes (predict.R#L59), so hardhat:::forge.data.frame runs with outcomes = FALSE. When the blueprint$forge$clean method is called, the outcome is returned as NULL even though the blueprint$recipe$term_info object still lists the outcome variable. When blueprint$forge$process is called and new_data is transformed using

new_data <- vctrs::vec_cbind(predictors, outcomes, !!!unname(extras$roles), 
    .name_repair = "minimal")

(hardhat::blueprint-recipe-default.R#L341), new_data no longer contains the outcome variable. A few steps later, when new_data reaches the bake step, if the recipe involves pre-processing of outcome variables, the following error is thrown: Error: Can't subset columns that don't exist.
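
The failure mode described here can be illustrated without any of the packages involved. The following is a minimal base R sketch, not the recipes implementation: prep_normalize and bake_normalize are hypothetical stand-ins for the prep and bake stages of a normalize step, and the only assumption carried over from the stack trace is that the bake stage subsets new_data by the names of the stored means.

```r
# Hypothetical stand-in for the prep stage: learn means and sds by column name
prep_normalize <- function(training, cols) {
  list(
    means = vapply(training[cols], mean, numeric(1)),
    sds   = vapply(training[cols], sd, numeric(1))
  )
}

# Hypothetical stand-in for the bake stage: subset by the stored names,
# then center and scale those columns
bake_normalize <- function(step, new_data) {
  cols <- names(step$means)
  # This subset errors if any stored column is absent from new_data
  selected <- as.matrix(new_data[, cols, drop = FALSE])
  centered <- sweep(selected, 2, step$means, "-")
  scaled   <- sweep(centered, 2, step$sds, "/")
  new_data[cols] <- as.data.frame(scaled)
  new_data
}

set.seed(1)
train <- data.frame(y = rnorm(20), x1 = rnorm(20))
step  <- prep_normalize(train, "y")

# Outcome column present: the step can be applied
baked <- bake_normalize(step, train)

# Outcome column dropped before baking, as happens at predict time
no_outcome <- train[, "x1", drop = FALSE]
result <- tryCatch(
  bake_normalize(step, no_outcome),
  error = function(e) conditionMessage(e)
)
print(result)
```

A base data.frame reports the missing column with its own message rather than the tibble one shown in the stack trace, but the cause is the same: the step remembers a column that is no longer in new_data.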

That's the extent of what I can make of it from debugging.
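
To isolate this from workflows and xgboost, the same situation can be sketched with hardhat::mold and hardhat::forge directly. This is a sketch under assumptions rather than a confirmed fix: it assumes hardhat, recipes and magrittr are installed, the call without outcomes is wrapped in tryCatch because its result may depend on the installed versions, and the skip = TRUE variant is only one possible way to keep an outcome step out of the bake stage.

```r
library(magrittr)  # provides the %>% pipe

set.seed(1)
test <- setNames(
  as.data.frame(matrix(rnorm(1000), ncol = 10)),
  c("y", paste0("x", 1:9))
)

rec <- recipes::recipe(y ~ ., data = test) %>%
  recipes::step_normalize(recipes::all_outcomes())

# mold() preps the recipe and stores it in a blueprint
molded <- hardhat::mold(rec, test)

# forge() without outcomes, which is how predict.workflow calls it.
# Captures the error message instead of stopping, if an error is thrown.
without_outcomes <- tryCatch(
  hardhat::forge(test, molded$blueprint),
  error = function(e) conditionMessage(e)
)

# forge() with outcomes requested: y is kept, so the step can find it
with_outcomes <- hardhat::forge(test, molded$blueprint, outcomes = TRUE)

# Possible workaround: skip the outcome step when baking new data.
# The step is still applied to the training set during prep.
rec_skip <- recipes::recipe(y ~ ., data = test) %>%
  recipes::step_normalize(recipes::all_outcomes(), skip = TRUE)

molded_skip <- hardhat::mold(rec_skip, test)
forged_skip <- hardhat::forge(test, molded_skip$blueprint)
```

Note that with skip = TRUE the model is trained on the normalized outcome, so predictions would come back on the normalized scale and would need to be transformed back by hand.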

Hope it's an easy fix!

sessionInfo:

R version 3.5.3 (2019-03-11)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252 
[2] LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets 
[6] methods   base     

other attached packages:
 [1] yardstick_0.0.6      tibble_3.0.3.9000   
 [3] rsample_0.0.7        recipes_0.1.10      
 [5] purrr_0.3.4.9000     parsnip_0.1.2       
 [7] infer_0.5.2          ggplot2_3.3.2       
 [9] dials_0.0.8          scales_1.1.1.9000   
[11] broom_0.7.0          RevoUtils_11.0.3    
[13] AlpacaforR_1.0.0     printr_0.1          
[15] dplyr_1.0.0          magrittr_1.5        
[17] RevoUtilsMath_11.0.0

loaded via a namespace (and not attached):
  [1] backports_1.1.8        tidytext_0.2.5        
  [3] workflows_0.1.2        plyr_1.8.6            
  [5] igraph_1.2.5           splines_3.5.3         
  [7] websocket_1.1.0        crosstalk_1.1.0.1     
  [9] listenv_0.8.0          SnowballC_0.7.0       
 [11] rstantools_2.1.1       inline_0.3.15         
 [13] digest_0.6.25.1        foreach_1.5.0         
 [15] htmltools_0.5.0.9000   rsconnect_0.8.16      
 [17] fansi_0.4.1            memoise_1.1.0         
 [19] globals_0.12.5         gower_0.2.2           
 [21] RcppParallel_5.0.2     doFuture_0.9.0        
 [23] matrixStats_0.56.0     xts_0.12-0            
 [25] hardhat_0.1.4          anytime_0.3.7         
 [27] prettyunits_1.1.1      colorspace_1.4-1      
 [29] blob_1.2.1             xfun_0.15.1           
 [31] callr_3.4.3            crayon_1.3.4          
 [33] jsonlite_1.7.0         lme4_1.1-23           
 [35] survival_2.43-3        zoo_1.8-8             
 [37] iterators_1.0.12       glue_1.4.1.9000       
 [39] gtable_0.3.0           ipred_0.9-9           
 [41] pkgbuild_1.0.8         rstan_2.19.3          
 [43] DBI_1.1.0              miniUI_0.1.1.1        
 [45] Rcpp_1.0.5.1           xtable_1.8-4          
 [47] GPfit_1.0-8            bit_1.1-15.2          
 [49] DT_0.14.1              stats4_3.5.3          
 [51] lava_1.6.7             StanHeaders_2.21.0-5  
 [53] prodlim_2019.11.13     htmlwidgets_1.5.1.9001
 [55] httr_1.4.1             threejs_0.3.3         
 [57] ellipsis_0.3.1         pkgconfig_2.0.3       
 [59] loo_2.3.0              nnet_7.3-12           
 [61] dbplyr_1.4.2           utf8_1.1.4            
 [63] reshape2_1.4.4         tidyselect_1.1.0      
 [65] rlang_0.4.7            DiceDesign_1.8-1      
 [67] later_1.1.0.9000       munsell_0.5.0         
 [69] tools_3.5.3            xgboost_0.90.0.2      
 [71] cli_2.0.2              generics_0.0.2        
 [73] RSQLite_2.2.0.9000     ggridges_0.5.2        
 [75] stringr_1.4.0          fastmap_1.0.1         
 [77] yaml_2.2.1             processx_3.4.3        
 [79] knitr_1.29.3           bit64_0.9-7           
 [81] fs_1.4.1.9000          packrat_0.5.0         
 [83] nlme_3.1-137           future_1.18.0         
 [85] mime_0.9               catchr_0.2.2          
 [87] rstanarm_2.19.3        tictoc_1.0            
 [89] tokenizers_0.2.1       shinythemes_1.1.2     
 [91] compiler_3.5.3         bayesplot_1.7.2       
 [93] rstudioapi_0.11        testthat_2.3.2        
 [95] tidyposterior_0.0.3    statmod_1.4.34        
 [97] lhs_1.0.2              stringi_1.4.7         
 [99] ps_1.3.3               desc_1.2.0            
[101] lattice_0.20-38        Matrix_1.2-15         
[103] nloptr_1.2.2.1         markdown_1.1          
[105] shinyjs_1.1            vctrs_0.3.1           
[107] tidymodels_0.0.2       pillar_1.4.6.9000     
[109] lifecycle_0.2.0.9000   furrr_0.1.0.9002      
[111] data.table_1.12.8      httpuv_1.5.4.9000     
[113] R6_2.4.1               promises_1.1.1.9000   
[115] gridExtra_2.3          janeaustenr_0.1.5     
[117] codetools_0.2-16       boot_1.3-20           
[119] gtools_3.8.2           colourpicker_1.0      
[121] MASS_7.3-51.1          assertthat_0.2.1      
[123] pkgload_1.1.0          rprojroot_1.3-2       
[125] withr_2.2.0            HDA_0.0.0.9000        
[127] shinystan_2.5.0        parallel_3.5.3        
[129] tsibble_0.9.1          grid_3.5.3            
[131] rpart_4.1-13           timeDate_3043.102     
[133] minqa_1.2.4            tidyr_1.1.0           
[135] class_7.3-15           pROC_1.16.2           
[137] tidypredict_0.4.5      shiny_1.5.0.9000      
[139] lubridate_1.7.9        base64enc_0.1-3       
[141] dygraphs_1.1.1.6      