idem-lab / example-greta-targets

example greta workflow using targets

greta has non-exportable R objects so we can't have them in targets... for the moment #1

Open njtierney opened 11 months ago

njtierney commented 11 months ago

The issue with using greta in targets is that greta uses reticulate to talk to Python, and when the R session ends, the connection to the Python objects is lost (they become NULL externalptr objects) - see https://rstudio.github.io/reticulate/articles/package.html#implementing-s3-methods
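As a minimal sketch of the symptom (assuming numpy is available through reticulate; py_is_null_xptr() is reticulate's helper for exactly this check):

library(reticulate)
np <- import("numpy", convert = FALSE)
x <- np$array(c(1, 2, 3))  # x holds an external pointer to a Python object
saveRDS(x, "np-array.rds")

# ...then, in a fresh R session:
x <- readRDS("np-array.rds")
reticulate::py_is_null_xptr(x)
# TRUE: the pointer no longer refers to a live Python object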

This problem is discussed somewhat in the future package's vignette on non-exportable objects: https://cran.r-project.org/web/packages/future/vignettes/future-4-non-exportable-objects.html and also in the targets documentation: https://books.ropensci.org/targets/targets.html#saving

However, it seems that keras and torch models can be safely saved if you specify format = "keras" or format = "torch" (https://books.ropensci.org/targets/targets.html#fn4), so perhaps that is something we can look into. It seems this uses a technique called "marshalling" (https://docs.ropensci.org/targets/reference/tar_format.html#marshalling) - I need to read more here.

greta uses future to do parallelisation, but in a way that I don't fully understand, so I'm not sure if we can use some of those techniques to get around these issues. My guess is probably not, since in that case the R session is active the whole time.

njtierney commented 11 months ago

Under the hood, it seems the format = "keras" and format = "torch" options use special save and load functions from those packages.

Reading more about specifying tar_format in targets, it looks like there is a defined way to use keras and torch objects:

https://docs.ropensci.org/targets/reference/tar_format.html#ref-examples

# The following target is equivalent to the current superseded
# tar_target(name, command(), format = "keras").
# An improved version of this would supply a `convert` argument
# to handle NULL objects, which are returned by the target if it
# errors and the error argument of tar_target() is "null".
tar_target(
  name = keras_target,
  command = your_function(),
  format = tar_format(
    read = function(path) {
      keras::load_model_hdf5(path)
    },
    write = function(object, path) {
      keras::save_model_hdf5(object = object, filepath = path)
    },
    marshal = function(object) {
      keras::serialize_model(object)
    },
    unmarshal = function(object) {
      keras::unserialize_model(object)
    }
  )
)
# And the following is equivalent to the current superseded
# tar_target(name, torch::torch_tensor(seq_len(4)), format = "torch").
# As with the keras example, an improved version would supply a
# `convert` argument to handle cases when `NULL` is returned
# (e.g. if the target errors out and the `error` argument is
# "null" in tar_target() or tar_option_set()).
tar_target(
  name = torch_target,
  command = torch::torch_tensor(seq_len(4)),
  format = tar_format(
    read = function(path) {
      torch::torch_load(path)
    },
    write = function(object, path) {
      torch::torch_save(obj = object, path = path)
    },
    marshal = function(object) {
      con <- rawConnection(raw(), open = "wr")
      on.exit(close(con))
      torch::torch_save(object, con)
      rawConnectionValue(con)
    },
    unmarshal = function(object) {
      con <- rawConnection(object, open = "r")
      on.exit(close(con))
      torch::torch_load(con)
    }
  )
)

So perhaps we need to use some form of those functions to save greta objects, or create special write_greta() and read_greta() functions that piggyback off the tensorflow saving but also record the metadata of how the greta object is linked to everything, so that read_greta() can carefully unpack it all correctly.
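For example, a hypothetical tar_format() for greta models might look something like this (write_greta() and read_greta() don't exist; all names here are illustrative only):

tar_target(
  name = greta_model,
  command = define_model(),  # hypothetical user function returning a greta model
  format = tar_format(
    read = function(path) {
      read_greta(path)  # hypothetical: reload TF objects + greta metadata
    },
    write = function(object, path) {
      write_greta(object = object, path = path)  # hypothetical: TF saving + linkage metadata
    }
  )
)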

goldingn commented 11 months ago

When defining greta arrays, no TF objects are created. greta array objects are linked to one another via R environments, but no pointers are involved. Pointers to python/TF objects are only created when model() is called on them, and sometimes again when mcmc() and other inference methods are run. When future is used to run models in parallel, there's a need to redefine models and pointers in the same way as described here, so there are methods in there somewhere to do that.
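To illustrate that lifecycle (a sketch; the comments summarise the behaviour described above):

library(greta)
x <- normal(0, 1)  # greta array: plain R object + environments, no pointers yet
y <- x * 2         # still pure R, linked to x via environments
m <- model(y)      # model() builds the dag; Python/TF pointers created here
draws <- mcmc(m)   # inference methods can create further pointers again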

If you just need to reload the model and do things with it (as it looks like in https://github.com/njtierney/example-greta-targets/blob/main/r-script-only/attempt-explode-method.R), then redefining the TF objects might be all that's needed, e.g. using the methods of the dag object inside the model object.

However, it sounds like there are other issues with reloading and manipulating the greta array objects too (from something Saras said), so that might be a separate problem and need another issue.

smwindecker commented 11 months ago

If I understand @goldingn's comment, yes - my example repo looks a tiny bit different because, at least at the moment, I'm hoping to have separate targets for the greta arrays. A short version is here for reference: https://github.com/idem-lab/targets-pkg-greta, where the _targets.R file is created using build_pipeline().

smwindecker commented 11 months ago

note in this version you can run targets::tar_make() and it'll only fail at the last target (the draws), so you can later try

targets::tar_load(m)
plot(m)

and it successfully makes the dag.

After loading m and greta_arrays into the environment, tar_make() will still not make the draws (which makes sense because it's not looking in the env). But it also won't let you run mcmc(m) in the console with those objects in the environment.

goldingn commented 11 months ago

Thanks @smwindecker, that helped as a reprex. As I thought, the issue is just redefining the pointers in the new environment, like we already do when running in parallel: https://github.com/greta-dev/greta/blob/tf2-poke-tf-fun/R/inference_class.R#L346-L354

So in that repo, if I run the following:

targets::tar_make(m)
targets::tar_load(m)
library(greta)
# this plots
plot(m)
# but can't sample
draws <- mcmc(m, n_samples = 1000, chains = 4)

I get

Error: Unable to access object (object is from previous session and is now invalid)

But in a new session, if we manually redefine the objects needed, like in the code for parallel new environments, it works:

targets::tar_make(m)
targets::tar_load(m)
library(greta)

# there is no log prob function defined (dead pointer) and a few other things
m$dag$tf_log_prob_function

# force rebuild these things, like we do for parallel things
dag <- m$dag
dag$define_tf_trace_values_batch()
dag$define_tf_log_prob_function()

# now it samples
draws <- mcmc(m, n_samples = 1000, chains = 4)
running 4 chains simultaneously on up to 8 CPU cores

   warmup ====================================== 1000/1000 | eta:  0s          
 sampling ====================================== 1000/1000 | eta:  0s  

So instead of explicitly recreating those pointers when in parallel, why not add a step to check for dead pointers and redefine them, whenever any of those TF objects are accessed? Then we can delete the parallel-specific code, and this should work in targets.

njtierney commented 11 months ago

Amazing, thanks so much, @goldingn and @smwindecker !

When you say:

So instead of explicitly recreating those pointers when in parallel, why not add a step to check for dead pointers and redefine them, whenever any of those TF objects are accessed? Then we can delete the parallel-specific code, and this should work in targets.

Would you like this to be something baked into greta, or a step in the targets pipeline? Overall I think it would be great to have it baked into greta, potentially with some explicit messages and guard rails built in. For example, you could set a global option to escalate this from a message to a warning or an error when dead pointers are detected in greta code, just in case this is undesirable behaviour.
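For example, something like this hypothetical helper (the option name and function are illustrative, not part of greta):

# respond to a detected dead pointer according to a global option
alert_dead_pointer <- function(msg) {
  action <- getOption("greta.dead_pointer_action", default = "message")
  switch(
    action,
    message = message(msg),
    warning = warning(msg),
    error = stop(msg)
  )
}

# users could then opt in to stricter behaviour with:
# options(greta.dead_pointer_action = "error")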

goldingn commented 11 months ago

It would have to be baked into greta to work, I think.

I don't think we'd need an error or warning at all. Whenever greta needs to evaluate one of these pointers (there are only a couple now), we'd just add a little wrapper function to:

  1. See if the object exists and, if it does, that it isn't a dead pointer.
  2. If the pointer isn't there, execute the code to create it.
  3. Check the pointer now exists, and error if not.

So for the log prob function, we'd call something like self$maybe_define_tf_log_prob_function() before doing self$tf_log_prob_function(parameters)

Or even better, wrap those into another function, self$evaluate_log_prob_function(parameters), which does both of those steps.
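As a sketch of that pattern (method names follow the proposal above; is_dead_pointer() stands in for whatever check the TF pointers need, e.g. something like reticulate::py_is_null_xptr()):

# sketch of the two methods on the dag class
maybe_define_tf_log_prob_function = function() {
  fn <- self$tf_log_prob_function
  # 1. see if the object exists and isn't a dead pointer
  if (is.null(fn) || is_dead_pointer(fn)) {
    # 2. if the pointer isn't there, execute the code to create it
    self$define_tf_log_prob_function()
  }
  # 3. check the pointer now exists, and error if not
  if (is.null(self$tf_log_prob_function)) {
    stop("could not define tf_log_prob_function")
  }
},

evaluate_log_prob_function = function(parameters) {
  self$maybe_define_tf_log_prob_function()
  self$tf_log_prob_function(parameters)
}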

We might need the same pattern for the handful of other TF pointers that are only defined once and then reused (maybe the internal MCMC draws object when using extra_samples()?). But then we could also delete the code in other places that currently defines these pointers.