kapelner / bartMachine

An R-Java Bayesian Additive Regression Trees implementation
MIT License

bart_machine_get_posterior() doesn't work with bartMachineArr() lists restored from disk #38

Closed: bakaburg1 closed this issue 3 years ago

bakaburg1 commented 3 years ago

Hello,

I'm generating an ensemble model with bartMachineArr() to produce a more robust posterior predictive distribution, and I need to save the model for later use. When I restore the array from disk, though, only the first model works with bart_machine_get_posterior(). For the others I get:

 Error in check_serialization(bart_machine) : 
  This bartMachine object was loaded from an R image but was not serialized.
  Please build bartMachine using the option "serialize = TRUE" next time.

I suspect the serialize argument of bartMachine() is not passed on to the other models, or that some connection is lost.

Here's the dummy code to produce the model:

n_models <- 5

model <- bartMachine(X = X, y = y, serialize = TRUE, ...)

if (n_models > 1) {
    model <- bartMachineArr(model, R = n_models)
} else {
    model <- list(model)
}

readr::write_rds(model, 'model.rds', compress = 'gz')

And to produce averaged predictive posteriors:

pred_post <- bart_machine_get_posterior(model[[1]], new_data = data)$y_hat_posterior_samples

if (n_models > 1) {
    for (i in 2:n_models) {
        pred_post <- pred_post + bart_machine_get_posterior(model[[i]], new_data = data)$y_hat_posterior_samples
    }
}

pred_post <- pred_post / n_models
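The loop above computes an element-wise mean of the posterior draw matrices. A minimal sketch of the same averaging with Reduce(), assuming every model returns a y_hat_posterior_samples matrix of identical dimensions for the same new_data; the helper name average_posterior_draws is hypothetical and not part of bartMachine.

```r
# Hypothetical helper (not part of bartMachine): element-wise mean of a list
# of equally sized matrices, e.g. one y_hat_posterior_samples matrix per model
average_posterior_draws <- function(draw_list) {
  stopifnot(length(draw_list) > 0)
  # Sum the matrices pairwise, then divide by the number of models
  Reduce(`+`, draw_list) / length(draw_list)
}

# Hypothetical usage with the restored list `model` from above:
# draws <- lapply(model, function(m)
#   bart_machine_get_posterior(m, new_data = data)$y_hat_posterior_samples)
# pred_post <- average_posterior_draws(draws)
```

Separating the averaging from the calls to bart_machine_get_posterior() removes the special case for model[[1]] and the `if (n_models > 1)` branch.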

My session info:

R version 4.0.5 (2021-03-31)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur 10.16

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] it_IT.UTF-8/it_IT.UTF-8/it_IT.UTF-8/C/it_IT.UTF-8/it_IT.UTF-8

attached base packages:
[1] grid      parallel  stats     graphics  grDevices utils     datasets  methods  
[9] base     

other attached packages:
 [1] knitr_1.31          tidytrees_0.2.2     pROC_1.17.0.1       magrittr_2.0.1     
 [5] pbmcapply_1.5.0     pbapply_1.4-3       readr_1.4.0         glue_1.4.2         
 [9] bartMachine_1.2.6   missForest_1.4      itertools_0.1-3     iterators_1.0.13   
[13] foreach_1.5.1       randomForest_4.6-14 bartMachineJARs_1.1 rJava_0.9-13       
[17] bayestestR_0.8.2    partykit_1.2-11     mvtnorm_1.1-1       libcoin_1.0-7      
[21] readxl_1.3.1        stringr_1.4.0       dplyr_1.0.6        

loaded via a namespace (and not attached):
 [1] pkgload_1.1.0      splines_4.0.5      Formula_1.2-4      assertthat_0.2.1  
 [5] pander_0.6.3       cellranger_1.1.0   yaml_2.2.1         remotes_2.2.0     
 [9] sessioninfo_1.1.1  pillar_1.6.0       backports_1.2.1    lattice_0.20-41   
[13] digest_0.6.27      pryr_0.1.4         checkmate_2.0.0    htmltools_0.5.1.1 
[17] Matrix_1.3-2       plyr_1.8.6         pkgconfig_2.0.3    devtools_2.3.2    
[21] magick_2.6.0       purrr_0.3.4        processx_3.5.2     tibble_3.1.1      
[25] generics_0.1.0     usethis_2.0.0      ellipsis_0.3.2     cachem_1.0.4      
[29] withr_2.4.1        cli_2.5.0          survival_3.2-10    crayon_1.4.1      
[33] memoise_2.0.0      evaluate_0.14      ps_1.6.0           fs_1.5.0          
[37] fansi_0.4.2        pkgbuild_1.2.0     rapportools_1.0    tools_4.0.5       
[41] prettyunits_1.1.1  hms_1.0.0          matrixStats_0.58.0 lifecycle_1.0.0   
[45] callr_3.5.1        compiler_4.0.5     inum_1.0-2         tinytex_0.30      
[49] rlang_0.4.11       base64enc_0.1-3    rmarkdown_2.7      testthat_3.0.1    
[53] codetools_0.2-18   DBI_1.1.1          R6_2.5.0           lubridate_1.7.10  
[57] fastmap_1.1.0      utf8_1.2.1         rprojroot_2.0.2    insight_0.12.0    
[61] desc_1.3.0         stringi_1.6.1      Rcpp_1.0.6         vctrs_0.3.8       
[65] rpart_4.1-15       tidyselect_1.1.1   xfun_0.22

bakaburg1 commented 3 years ago

UPDATE: If I build the list manually, the error doesn't occur:

model <- lapply(1:n_models, function(i) {
    bartMachine(X = X, y = y, serialize = TRUE, verbose = FALSE, ...)
})

It's just a bit annoying, since it prints "serializing in order to be saved for future R sessions...done" at every iteration (and I think it is a bit slower).
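One way to keep the manual workaround but hide the repeated progress text is to wrap each fit in capture.output() and suppressMessages(). This is a sketch under an unverified assumption: that the "serializing..." text is written either to standard output (e.g. via cat()) or as an R message. The helper build_ensemble_quietly and its fit_fun argument are hypothetical names, not bartMachine API. It hides the text only; it does not make serialization any faster.

```r
# Hypothetical helper: call fit_fun() n_models times and return the fits as a
# list, discarding anything printed to stdout and any R messages along the way
build_ensemble_quietly <- function(fit_fun, n_models) {
  lapply(seq_len(n_models), function(i) {
    model <- NULL
    # capture.output() swallows cat()-style output; suppressMessages() swallows
    # message() output. The assignment happens in this function's environment.
    suppressMessages(utils::capture.output(model <- fit_fun()))
    model
  })
}

# Hypothetical usage mirroring the UPDATE above:
# model <- build_ensemble_quietly(
#   function() bartMachine(X = X, y = y, serialize = TRUE, verbose = FALSE),
#   n_models
# )
```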

kapelner commented 3 years ago

Thanks for the report. The bartMachineArr() function is really just the wrapper you wrote in your UPDATE, but internally it uses the bart_machine_duplicate() method, which sets serialize = FALSE. Because of their size, these arrays were never intended to be saved. So if you need this feature, use your code.