rstudio / bundle

Prepare objects for serialization with a consistent interface
https://rstudio.github.io/bundle/
Other
25 stars 4 forks source link

predictions from a bart classification model #64

Closed jdberson closed 2 weeks ago

jdberson commented 2 weeks ago

Hello

I’m having trouble making probability predictions in a new R session from a bundled bart classification model object. Not sure if this is a bug or I am doing something silly.

Predicting from an unbundled bart classification model in the same R session works, but when I use saveRDS + readRDS, the new predictions are incorrect and tend to be clustered around 0.5.

I have included a reprex below to show the problem.

Thanks for all the great packages and for your help!

Warm regards

Jacob

# Example to show that bundling a bart classification model object and then 
# predicting the probabilities from this object in a new R session results in 
# different probability predictions that are clustered around 0.5.

## Setup

library(tidymodels)
library(bundle)
library(xgboost, quietly = TRUE, warn.conflicts = FALSE)
library(dbarts, quietly = TRUE, warn.conflicts = FALSE)
library(callr)
library(waldo)
library(stringr, quietly = TRUE, warn.conflicts = FALSE)

# Let tidymodels functions win name conflicts with other attached packages.
tidymodels_prefer()

# Data: the `two_class_dat` example data set from the modeldata package.
data("two_class_dat", package = "modeldata")

# Seed the RNG once, before fitting. Nothing below re-seeds, so any
# RNG-dependent step that follows continues from this single stream.
set.seed(1)

## Fit model and make predictions

# Fit a BART classification model (dbarts engine) on the first 70 rows.
mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

# Bundle the fitted model so it can be passed to another R session.
mod_bart_bundle <- bundle(mod_bart)

# Baseline: unbundle in the *same* session and predict class probabilities
# for rows 71-200 (rows that were not used for fitting).
mod_bart_unbundled <- unbundle(mod_bart_bundle)
bart_predictions_existing_session <- 
  predict(mod_bart_unbundled, two_class_dat[71:200, ], type = "prob")

# Make predictions using the bundled model in a new R session.
# callr::r() evaluates the function in a separate R process, so the bundle and
# the new data are handed over through `args`, and the packages needed for
# unbundling and predicting are attached inside the function.
bart_predictions_new_session <-
  r(
    function(model_bundle, new_data) {
      library(bundle)
      library(parsnip)

      # Restore the model object from the bundle inside the child session.
      model_object <- unbundle(model_bundle)

      # Same prediction call as in the existing session (no seed is set here).
      predict(model_object, new_data, type = "prob")
    },
    args = list(
      model_bundle = mod_bart_bundle,
      new_data = two_class_dat[71:200, ]
    )
  )

# Compare the same-session predictions ("old") with the new-session
# predictions ("new"); waldo::compare() prints the differences it finds.
compare(bart_predictions_existing_session, bart_predictions_new_session)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.844        0.156
#> + new[1, ]          0.477        0.523
#> - old[2, ]          0.581        0.419
#> + new[2, ]          0.486        0.514
#> - old[3, ]          0.396        0.604
#> + new[3, ]          0.488        0.512
#> - old[4, ]          0.722        0.278
#> + new[4, ]          0.514        0.486
#> - old[5, ]          0.596        0.404
#> + new[5, ]          0.501        0.499
#> - old[6, ]          0.545        0.455
#> + new[6, ]          0.482        0.518
#> - old[7, ]          0.860        0.140
#> + new[7, ]          0.487        0.513
#> - old[8, ]          0.635        0.365
#> + new[8, ]          0.509        0.491
#> - old[9, ]          0.195        0.805
#> + new[9, ]          0.521        0.479
#> - old[10, ]         0.399        0.601
#> + new[10, ]         0.503        0.497
#> and 120 more ...
#> 
#>      old$.pred_Class1 | new$.pred_Class1                 
#>  [1] 0.844            - 0.477            [1]             
#>  [2] 0.581            - 0.486            [2]             
#>  [3] 0.396            - 0.488            [3]             
#>  [4] 0.722            - 0.514            [4]             
#>  [5] 0.596            - 0.501            [5]             
#>  [6] 0.545            - 0.482            [6]             
#>  [7] 0.860            - 0.487            [7]             
#>  [8] 0.635            - 0.509            [8]             
#>  [9] 0.195            - 0.521            [9]             
#> [10] 0.399            - 0.503            [10]            
#>  ... ...                ...              and 120 more ...
#> 
#>      old$.pred_Class2 | new$.pred_Class2                 
#>  [1] 0.156            - 0.523            [1]             
#>  [2] 0.419            - 0.514            [2]             
#>  [3] 0.604            - 0.512            [3]             
#>  [4] 0.278            - 0.486            [4]             
#>  [5] 0.404            - 0.499            [5]             
#>  [6] 0.455            - 0.518            [6]             
#>  [7] 0.140            - 0.513            [7]             
#>  [8] 0.365            - 0.491            [8]             
#>  [9] 0.805            - 0.479            [9]             
#> [10] 0.601            - 0.497            [10]            
#>  ... ...                ...              and 120 more ...

# Scatter plot of the class 1 probabilities: existing-session predictions on
# the x axis, new-session predictions on the y axis. Columns of each prediction
# tibble get a suffix ("_existing" / "_new") so they can sit side by side.
bart_predictions_existing_session |>
  rename_with(.fn = \(nm) str_c(nm, "_existing")) |>
  bind_cols(
    bart_predictions_new_session |>
      rename_with(.fn = \(nm) str_c(nm, "_new"))
  ) |>
  ggplot(aes(x = .pred_Class1_existing, y = .pred_Class1_new)) +
  geom_point() +
  theme_bw()

Created on 2024-07-09 with reprex v2.1.0.9000

simonpcouch commented 2 weeks ago

We currently don't supply any dbarts-specific functionality in bundle, but looks like dbart objects have an extptr that we lose access to in new sessions. Strange to me that the model will predict without it anyway.

library(tidymodels)
library(callr)
library(waldo)

# Data: the `two_class_dat` example data set from the modeldata package.
data("two_class_dat", package = "modeldata")

# Model specification only; nothing is fitted yet.
mod_bart <- bart(mode = "classification", engine = "dbarts") 

# Fit two models on the same rows, re-seeding before each fit so both start
# from the same RNG state.
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

set.seed(1)
bart_fit_2 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

# Some differences (external pointer addresses and timings, see the output
# below), but that is fine as long as they predict the same
compare(bart_fit_1, bart_fit_2)
#> `old$fit$fit@.xData$.->pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$.->pointer` is <pointer: 0x11bee3570>
#> 
#> `old$fit$fit@.xData$pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$pointer` is <pointer: 0x11bee3570>
#> 
#> `attr(old$fit$fit@.xData$state, 'runningTime')`: 0.397
#> `attr(new$fit$fit@.xData$state, 'runningTime')`: 0.403
#> 
#> `old$elapsed$elapsed`: 0.44 0.01 0.45 0.00 0.00
#> `new$elapsed$elapsed`: 0.40 0.01 0.41 0.00 0.00
# ...which they do, when the seed is reset before each predict() call
set.seed(1)
p_fit_1 <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
set.seed(1)
p_fit_2 <- predict(bart_fit_2, two_class_dat[71:200, ], type = "prob")
compare(p_fit_1, p_fit_2)
#> ✔ No differences
# That pointer no longer exists in a new session: after the fitted object is
# passed to a fresh R process via callr::r(), the slot prints as a null
# pointer (0x0).
r(
  function(model_object) {
    model_object$fit$fit@.xData$pointer
  },
  args = list(
    model_object = bart_fit_1
  )
)
#> <pointer: 0x0>
#> <pointer: 0x0>

Created on 2024-07-09 with reprex v2.1.0

Some other miscellaneous observations:

My first reflex was to see if the predictions from this model depend on the state of the RNG (without bundling). Looks like they do:

``` r
library(tidymodels)
library(waldo)

# Data
data("two_class_dat", package = "modeldata")

set.seed(1)

mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

compare(
  predict(mod_bart, two_class_dat[71:200, ], type = "prob"),
  predict(mod_bart, two_class_dat[71:200, ], type = "prob")
)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.844        0.156
#> + new[1, ]          0.830        0.170
#> - old[2, ]          0.581        0.419
#> + new[2, ]          0.570        0.430
#> - old[3, ]          0.396        0.604
#> + new[3, ]          0.377        0.623
#> - old[4, ]          0.722        0.278
#> + new[4, ]          0.733        0.267
#> - old[5, ]          0.596        0.404
#> + new[5, ]          0.605        0.395
#> - old[6, ]          0.545        0.455
#> + new[6, ]          0.522        0.478
#>   old[7, ]          0.860        0.140
#> - old[8, ]          0.635        0.365
#> + new[8, ]          0.630        0.370
#> - old[9, ]          0.195        0.805
#> + new[9, ]          0.194        0.806
#> - old[10, ]         0.399        0.601
#> + new[10, ]         0.407        0.593
#> and 120 more ...
#>
#>      old$.pred_Class1 | new$.pred_Class1
#>  [1] 0.844            - 0.830            [1]
#>  [2] 0.581            - 0.570            [2]
#>  [3] 0.396            - 0.377            [3]
#>  [4] 0.722            - 0.733            [4]
#>  [5] 0.596            - 0.605            [5]
#>  [6] 0.545            - 0.522            [6]
#>  [7] 0.860            | 0.860            [7]
#>  [8] 0.635            - 0.630            [8]
#>  [9] 0.195            - 0.194            [9]
#> [10] 0.399            - 0.407            [10]
#>  ... ...                ...              and 120 more ...
#>
#>      old$.pred_Class2 | new$.pred_Class2
#>  [1] 0.156            - 0.170            [1]
#>  [2] 0.419            - 0.430            [2]
#>  [3] 0.604            - 0.623            [3]
#>  [4] 0.278            - 0.267            [4]
#>  [5] 0.404            - 0.395            [5]
#>  [6] 0.455            - 0.478            [6]
#>  [7] 0.140            | 0.140            [7]
#>  [8] 0.365            - 0.370            [8]
#>  [9] 0.805            - 0.806            [9]
#> [10] 0.601            - 0.593            [10]
#>  ... ...                ...              and 120 more ...
```

``` r
set.seed(1)
p_1 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")

set.seed(1)
p_2 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")

compare(p_1, p_2)
#> ✔ No differences
```

Created on 2024-07-09 with [reprex v2.1.0](https://reprex.tidyverse.org)

We don't supply any dbarts-specific functionality in bundle, so I also wondered whether this was some side effect of bundling a `model_fit`, or whether the behavior is replicable without bundling. It is:

``` r
library(tidymodels)
library(bundle)
library(callr)
library(waldo)

tidymodels_prefer()

# Data
data("two_class_dat", package = "modeldata")

set.seed(1)

# fit the model
mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

set.seed(1)
bart_predictions_existing_session <-
  predict(mod_bart, two_class_dat[71:200, ], type = "prob")

# Make predictions using bundled model in a new R session
bart_predictions_new_session <-
  r(
    function(model_object, new_data) {
      library(parsnip)

      set.seed(1)
      predict(model_object, new_data, type = "prob")
    },
    args = list(
      model_object = mod_bart,
      new_data = two_class_dat[71:200, ]
    )
  )

compare(bart_predictions_existing_session, bart_predictions_new_session)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.823        0.177
#> + new[1, ]          0.520        0.480
#> - old[2, ]          0.568        0.432
#> + new[2, ]          0.519        0.481
#> - old[3, ]          0.361        0.639
#> + new[3, ]          0.519        0.481
#> - old[4, ]          0.721        0.279
#> + new[4, ]          0.508        0.492
#> - old[5, ]          0.599        0.401
#> + new[5, ]          0.498        0.502
#> - old[6, ]          0.549        0.451
#> + new[6, ]          0.478        0.522
#> - old[7, ]          0.863        0.137
#> + new[7, ]          0.474        0.526
#> - old[8, ]          0.647        0.353
#> + new[8, ]          0.522        0.478
#> - old[9, ]          0.200        0.800
#> + new[9, ]          0.502        0.498
#> - old[10, ]         0.394        0.606
#> + new[10, ]         0.507        0.493
#> and 120 more ...
#>
#>      old$.pred_Class1 | new$.pred_Class1
#>  [1] 0.823            - 0.520            [1]
#>  [2] 0.568            - 0.519            [2]
#>  [3] 0.361            - 0.519            [3]
#>  [4] 0.721            - 0.508            [4]
#>  [5] 0.599            - 0.498            [5]
#>  [6] 0.549            - 0.478            [6]
#>  [7] 0.863            - 0.474            [7]
#>  [8] 0.647            - 0.522            [8]
#>  [9] 0.200            - 0.502            [9]
#> [10] 0.394            - 0.507            [10]
#>  ... ...                ...              and 120 more ...
#>
#>      old$.pred_Class2 | new$.pred_Class2
#>  [1] 0.177            - 0.480            [1]
#>  [2] 0.432            - 0.481            [2]
#>  [3] 0.639            - 0.481            [3]
#>  [4] 0.279            - 0.492            [4]
#>  [5] 0.401            - 0.502            [5]
#>  [6] 0.451            - 0.522            [6]
#>  [7] 0.137            - 0.526            [7]
#>  [8] 0.353            - 0.478            [8]
#>  [9] 0.800            - 0.498            [9]
#> [10] 0.606            - 0.493            [10]
#>  ... ...                ...              and 120 more ...
```

Created on 2024-07-09 with [reprex v2.1.0](https://reprex.tidyverse.org)

simonpcouch commented 2 weeks ago

Okay, from dbarts' docs:

> Saving: `save`ing and `load`ing fitted BART objects for use with `predict` requires that R’s serialization mechanism be able to access the underlying trees, in addition to being fit with `keeptrees`/`keepTrees` as `TRUE`. For memory purposes, the trees are not stored as R objects unless specifically requested. To do this, one must “touch” the sampler’s state object before saving, e.g. for a fitted object `bartFit`, execute `invisible(bartFit$fit$state)`.

"Touch"ing the $fit$state slot is easy for bundle to do. Setting keeptrees = TRUE in the model fit is outside of bundle's scope, though parsnip sets keeptrees = TRUE by default. Maybe the bundle() method in this case would error if keeptrees = FALSE?

library(tidymodels)
library(callr)
library(waldo)

# Data: the `two_class_dat` example data set from the modeldata package.
data("two_class_dat", package = "modeldata")

# Model specification only; nothing is fitted yet.
mod_bart <- parsnip::bart(mode = "classification", engine = "dbarts")

# Fit the model on the first 70 rows, from a known RNG state.
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

# Reference predictions in the current session; re-seed so the new-session
# call below starts from the same RNG state.
set.seed(1)
p_orig <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")

# "Touch" the sampler's state before the object is serialized. Per the dbarts
# documentation on saving, the trees are not stored as R objects unless
# requested, and accessing `$state` is how to request that. The value itself
# is not needed, hence invisible().
invisible(bart_fit_1$fit$fit$state)

# Predict in a separate R process (callr::r()); the fitted object and the data
# are passed through `args`.
p_new_session <- r(
  function(model_object, two_class_dat) {
    library(parsnip)

    # Same seed as the in-session prediction above.
    set.seed(1)
    predict(model_object, two_class_dat[71:200, ], type = "prob")
  },
  args = list(
    model_object = bart_fit_1,
    two_class_dat = two_class_dat
  )
)

# With the state touched beforehand, both sessions give identical predictions.
compare(p_orig, p_new_session)
#> ✔ No differences

Created on 2024-07-09 with reprex v2.1.0

jdberson commented 2 weeks ago

Thanks very much @simonpcouch for looking into this and responding so quickly! It will be great if bundle could include dbarts functionality. In the meantime touching the $fit$state slot solves my problem, apologies I wasn't aware of this.

Cheers

Jacob

simonpcouch commented 2 weeks ago

> apologies I wasn't aware of this.

No worries at all! It's bundle's job to iron out these oddities for users, so we definitely ought to support this. I appreciate you pointing this one out.