mlr-org / mlr3

mlr3: Machine Learning in R - next generation
https://mlr3.mlr-org.com
GNU Lesser General Public License v3.0

[Suggestion] Serialization property for learners #891

Closed: sebffischer closed this issue 1 week ago

sebffischer commented 1 year ago

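Some learners (especially torch learners, but also e.g. lightgbm) break when they are saved and reloaded, because their fitted models reference external objects (pointers into C/C++ memory) that R's default serialization does not preserve. One situation where this occurs is when the result of benchmark(..., store_models = TRUE) is written to disk and read back.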

To avoid this, or at least to make it more convenient for the user, I suggest adding a learner property "serialize". If a learner has this property, it must implement a public method serialize() that converts the learner's $state into a serialized state.

To implement this, we could store the learner's state in a private field $.state and make $state an active binding that, when accessed while the state is serialized, unserializes it first.

This allows us to hide the serialization from the user in some situations. For example, benchmark() can call learner$serialize() when store_models is TRUE. Afterwards, the user can access the learner without calling learner$unserialize(), because the state is unserialized automatically as soon as it is accessed.

For example, using the bundle package, this might look as follows for LightGBM:

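```r
library(R6)
library(bundle)

LearnerClassifLightGBM = R6Class("LearnerClassifLightGBM",
  # ... (rest of the class definition elided)
  public = list(
    # Convert the state into its bundled (safely serializable) form.
    serialize = function() {
      if (!private$.serialized) {  # avoid bundling twice
        private$.state = bundle(private$.state)
        private$.serialized = TRUE
      }
      invisible(self)
    }
  ),
  private = list(
    .state = NULL,        # the learner state, possibly bundled
    .serialized = FALSE   # TRUE while .state holds a bundle
  ),
  active = list(
    state = function(x) {
      if (missing(x)) {
        # Reading: unbundle lazily if the state is currently bundled.
        if (private$.serialized) {
          private$.state = unbundle(private$.state)
          private$.serialized = FALSE
        }
      } else {
        # Writing: store the new, unbundled state.
        private$.state = x
        private$.serialized = FALSE
      }
      return(private$.state)
    }
  )
)
```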
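The sketch above depends on R6 and bundle. The same lazy-unserialize mechanism can be tried out with base R alone. In the minimal sketch below (an illustration, not mlr3 code), an environment plays the role of the R6 object, makeActiveBinding() provides the `state` binding, and base serialize()/unserialize() stand in for bundle()/unbundle(). The names new_state_holder() and is_serialized() are hypothetical.

```r
# Hypothetical base-R sketch: `state` is an active binding that transparently
# unserializes a serialized state when it is read.
# base::serialize()/unserialize() stand in for bundle()/unbundle().
new_state_holder = function(state = NULL) {
  self = new.env()
  .state = state         # "private" storage, kept in this function's frame
  .serialized = FALSE    # TRUE while .state holds a raw vector

  # Convert the stored state into a raw vector (e.g. before saving).
  self$serialize = function() {
    if (!.serialized) {
      .state <<- base::serialize(.state, connection = NULL)
      .serialized <<- TRUE
    }
    invisible(self)
  }
  self$is_serialized = function() .serialized

  # Reading `state` unserializes lazily; assigning to it replaces the state.
  makeActiveBinding("state", function(value) {
    if (missing(value)) {
      if (.serialized) {
        .state <<- base::unserialize(.state)
        .serialized <<- FALSE
      }
    } else {
      .state <<- value
      .serialized <<- FALSE
    }
    .state
  }, self)

  self
}

h = new_state_holder(list(model = "fitted model"))
h$serialize()
h$is_serialized()   # TRUE
h$state$model       # "fitted model" (unserialized on access)
h$is_serialized()   # FALSE
```

Assigning to `state` also resets the flag, so a freshly assigned state is never mistaken for a serialized one.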

In addition, it might be convenient to offer a $save(path) method that calls $serialize() and then saveRDS().
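The save idea can be sketched as a plain function so that it runs without mlr3. In this sketch, save_learner() and make_mock_learner() are hypothetical names, not mlr3 API: the mock learner is an environment that declares the "serialize" property, and its serialize() method uses base::serialize() in place of bundle().

```r
# Hypothetical helper (not mlr3 API): serialize the state first if the
# learner supports it, then write the whole object with saveRDS().
save_learner = function(learner, path) {
  if ("serialize" %in% learner$properties) {
    learner$serialize()
  }
  saveRDS(learner, path)
  invisible(path)
}

# Hypothetical stand-in learner: serialize() turns `state` into a raw vector.
make_mock_learner = function(state) {
  self = new.env()
  self$properties = "serialize"
  self$state = state
  self$serialize = function() {
    self$state = base::serialize(self$state, connection = NULL)
    invisible(self)
  }
  self
}

path = tempfile(fileext = ".rds")
save_learner(make_mock_learner(list(model = "fitted model")), path)
restored = readRDS(path)
is.raw(restored$state)                   # TRUE: stored in serialized form
base::unserialize(restored$state)$model  # "fitted model"
```

In the proposal this would be a learner method ($save(path)) rather than a free function; combined with the lazy active binding, the reloaded learner would unserialize its state on first access.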

sebffischer commented 1 year ago

Note that there is also an open issue in bundle: https://github.com/rstudio/bundle/issues/13

sebffischer commented 1 year ago

There is also this related package: https://github.com/HenrikBengtsson/marshal

sebffischer commented 1 week ago

This is implemented.