We currently don't supply any dbarts-specific functionality in bundle, but it looks like dbarts objects hold an external pointer (extptr) that we lose access to in new sessions. It's strange to me that the model will predict without it anyway.
library(tidymodels)
library(callr)
library(waldo)
# Data
data("two_class_dat", package = "modeldata")
mod_bart <- bart(mode = "classification", engine = "dbarts")
# fit two models, each with same seed
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
set.seed(1)
bart_fit_2 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
# some differences, but fine as long as they predict the same
compare(bart_fit_1, bart_fit_2)
#> `old$fit$fit@.xData$.->pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$.->pointer` is <pointer: 0x11bee3570>
#>
#> `old$fit$fit@.xData$pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$pointer` is <pointer: 0x11bee3570>
#>
#> `attr(old$fit$fit@.xData$state, 'runningTime')`: 0.397
#> `attr(new$fit$fit@.xData$state, 'runningTime')`: 0.403
#>
#> `old$elapsed$elapsed`: 0.44 0.01 0.45 0.00 0.00
#> `new$elapsed$elapsed`: 0.40 0.01 0.41 0.00 0.00
# ...which they do
set.seed(1)
p_fit_1 <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
set.seed(1)
p_fit_2 <- predict(bart_fit_2, two_class_dat[71:200, ], type = "prob")
compare(p_fit_1, p_fit_2)
#> ✔ No differences
# that pointer no longer exists in a new session
r(
  function(model_object) {
    model_object$fit$fit@.xData$pointer
  },
  args = list(
    model_object = bart_fit_1
  )
)
#> <pointer: 0x0>
Created on 2024-07-09 with reprex v2.1.0
Okay, from dbarts' docs:
Saving: saving and loading fitted BART objects for use with predict requires that R’s serialization mechanism be able to access the underlying trees, in addition to being fit with keeptrees/keepTrees as TRUE. For memory purposes, the trees are not stored as R objects unless specifically requested. To do this, one must “touch” the sampler’s state object before saving, e.g. for a fitted object bartFit, execute invisible(bartFit$fit$state).
"Touch"ing the $fit$state slot is easy for bundle to do. Setting keeptrees = TRUE in the model fit is outside of bundle's scope, though parsnip sets keeptrees = TRUE by default. Maybe the bundle() method in this case would error if keeptrees = FALSE?
library(tidymodels)
library(callr)
library(waldo)
# Data
data("two_class_dat", package = "modeldata")
mod_bart <- parsnip::bart(mode = "classification", engine = "dbarts")
# fit the model
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
set.seed(1)
p_orig <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
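# "touch" the sampler's state so the trees are stored as R objects
# and survive serialization into the new session
invisible(bart_fit_1$fit$fit$state)
# predict in a fresh R session
p_new_session <- r(
  function(model_object, two_class_dat) {
    library(parsnip)
    set.seed(1)
    predict(model_object, two_class_dat[71:200, ], type = "prob")
  },
  args = list(
    model_object = bart_fit_1,
    two_class_dat = two_class_dat
  )
)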
compare(p_orig, p_new_session)
#> ✔ No differences
Created on 2024-07-09 with reprex v2.1.0
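Until bundle handles this itself, the touch-then-save step from the reprex above could be wrapped in a small helper for the saveRDS + readRDS workflow. This is only a sketch: save_bart_fit() is a hypothetical name, not part of bundle, parsnip, or dbarts, and it assumes a parsnip fit with the dbarts engine whose sampler sits at $fit$fit, as in the reprex. It does not check keeptrees, which still has to be TRUE at fit time.

```r
# Hypothetical helper, not part of bundle, parsnip, or dbarts.
# Assumes `model_fit` is a parsnip fit using the dbarts engine, so the
# sampler is reachable at model_fit$fit$fit (as in the reprex above).
save_bart_fit <- function(model_fit, path) {
  # "Touch" the sampler state so the trees are stored as R objects
  # and are therefore included when the object is serialized.
  invisible(model_fit$fit$fit$state)
  saveRDS(model_fit, path)
  invisible(path)
}
```

It would be used as save_bart_fit(bart_fit_1, "bart_fit.rds"), followed by readRDS("bart_fit.rds") and predict() in the new session.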
Thanks very much @simonpcouch for looking into this and responding so quickly! It would be great if bundle could include dbarts functionality. In the meantime, touching the $fit$state slot solves my problem; apologies I wasn't aware of this.
Cheers
Jacob
> apologies I wasn't aware of this.
No worries at all! It's bundle's job to iron out these oddities for users, so we definitely ought to support this. I appreciate you pointing this one out.
Hello
I’m having trouble making probability predictions in a new R session from a bundled bart classification model object. Not sure if this is a bug or I am doing something silly.
Predicting from an unbundled bart classification model in the same R session works, but when I use saveRDS + readRDS, the new predictions are incorrect and tend to be clustered around 0.5.
I have included a reprex below to show the problem.
Thanks for all the great packages and for your help!
Warm regards
Jacob
Created on 2024-07-09 with reprex v2.1.0.9000