stan-dev / posterior

The posterior R package
https://mc-stan.org/posterior/

Efficient saving/reading of data frames containing rvars? #307

Open kthayashi opened 1 year ago

kthayashi commented 1 year ago

Thank you so much for this fantastic package, especially the rvars data type, which I really enjoy using. I've recently run into a situation where I'd like to save a data frame (about 700 rows by 10 variables) containing a single 1d rvar column from one script and read it back in another script. When I try to save the data frame with saveRDS, however, it takes 6+ minutes to complete and the resulting file is 1.5+ GB, even though the object is only ~20 MB in R. While this is workable, I was wondering if there are known solutions or recommended alternatives to saveRDS when working with rvars in this way. Thank you!
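
For concreteness, a rough sketch of the kind of object being described, with made-up column names and the default 4000 draws per cell (not the actual data):

library(posterior)
library(tibble)

set.seed(1)
# hypothetical stand-in: ~700 rows, one 1d rvar column
dat = tibble(
  id   = 1:700,
  pred = rvar_rng(rnorm, 700)
)

format(object.size(dat), units = "MB")  # on the order of 20 MB in memory
saveRDS(dat, "dat.rds")                 # slow, and the file is far larger than that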

mjskay commented 1 year ago

Ah, that's not good! Is the data frame you are saving a tibble?

Looking into this, it appears to be an issue with the cache that rvars keep so that they can be used efficiently with {vctrs} code (such as the code tibbles use). That cache holds several references back to the same rvar object, and the serializer writes each reference out as a separate copy of the object, hence the very large output.

For example:

library(posterior)

set.seed(1234)
df = data.frame(x = rvar_rng(rnorm, 10))
saveRDS(df, "df.rds")

# about 300 kB
file.size("df.rds") / 1024
[1] 300.3018
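
For scale, that is roughly what the raw draws alone should occupy, assuming the default of 4000 draws per element:

# back-of-the-envelope: 10 rvar elements x 4000 draws x 8 bytes per double, in KiB
10 * 4000 * 8 / 1024
[1] 312.5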

Same thing, but with a tibble:

library(posterior)
library(tibble)

set.seed(1234)
df = tibble(x = rvar_rng(rnorm, 10))
saveRDS(df, "df_tibble.rds")

# about 3 MB!
file.size("df_tibble.rds") / 1024
[1] 3301.285
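
You can see that this is a serialization effect rather than a genuinely large object by comparing the in-memory size with the serialized size, using only base R (a quick check on the tibble from above):

# in-memory size of the tibble (object.size does not follow the cache environments)
format(object.size(df), units = "Kb")
# serialized size in KiB: each cached reference is written out as a separate copy,
# so this is typically several times larger, on the order of the file size above
length(serialize(df, NULL)) / 1024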

How can we prevent this? One option is to stick to plain data frames throughout, but that is obviously unsatisfactory. Another option is to convert to a data frame just before saving; you'll also need to clear any rvar caches after the conversion:

set.seed(1234)
df = tibble(x = rvar_rng(rnorm, 10))

# convert to a data frame and clear rvar caches
df = as.data.frame(df)
rvar_i = sapply(df, is_rvar)
df[, rvar_i] = lapply(df[, rvar_i, drop = FALSE], posterior:::invalidate_rvar_cache)
# to avoid using the internal invalidate_rvar_cache function, you could
# also just apply an operation to the rvar that does nothing, like adding 0; e.g.:
# df[, rvar_i] = lapply(df[, rvar_i, drop = FALSE], \(x) x + 0)

saveRDS(df, "df_tibble.rds")

# about 300 kB again
file.size("df_tibble.rds") / 1024
[1] 300.3018
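
If you end up doing this a lot, the convert-and-invalidate steps could be wrapped in a small helper. Just a sketch, not a posterior function, and it leans on the internal invalidate_rvar_cache():

# hypothetical helper: convert to a plain data frame, clear rvar caches, then save
save_rvar_df = function(df, file, ...) {
  df = as.data.frame(df)
  rvar_i = sapply(df, is_rvar)
  df[, rvar_i] = lapply(df[, rvar_i, drop = FALSE], posterior:::invalidate_rvar_cache)
  saveRDS(df, file, ...)
}

save_rvar_df(df, "df_plain.rds")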

The final way is a bit more complicated, but does allow you to keep using tibbles. We can use the refhook argument to saveRDS and readRDS to make it so that the cache environments inside rvars are not saved out (this has no impact on rvar usage, as the rvar will regenerate the cache values automatically). Something like this should work:

set.seed(1234)
df = tibble(x = rvar_rng(rnorm, 10))

saveRDS(df, "df_tibble.rds", refhook = \(x) if (any(c("vec_proxy", "vec_proxy_equal") %in% names(x))) "")
# about 300 kB again
file.size("df_tibble.rds") / 1024
[1] 300.3076

Reading back in, we must supply a refhook as well:

new_df = readRDS("df_tibble.rds", refhook = \(x) new.env())
all.equal(df, new_df)
[1] TRUE
mjskay commented 1 year ago

I suppose we should probably document this somewhere and/or make it easier to do, e.g. by including the refhook functions above in the package.
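
For instance, the two refhooks could be bundled into paired wrappers along these lines (just a sketch; neither function currently exists in posterior):

# hypothetical wrappers around saveRDS()/readRDS() using the refhooks from above
saveRDS_rvar = function(object, file, ...) {
  # skip writing rvar cache environments; they are regenerated on use
  saveRDS(object, file, ...,
    refhook = \(x) if (any(c("vec_proxy", "vec_proxy_equal") %in% names(x))) "")
}

readRDS_rvar = function(file, ...) {
  # replace the skipped cache environments with fresh empty ones
  readRDS(file, ..., refhook = \(x) new.env())
}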

kthayashi commented 1 year ago

The data frame in question is indeed a tibble (apologies, I forgot to clarify that upfront). Thank you so much for walking through the issue and the suggested workarounds. It looks like sticking to a data frame might best suit my needs for now, but I agree that having some sort of function in the package that helps handle this (sort of like how cmdstanr has the $save_object() method) would be very convenient!

mjskay commented 1 year ago

Hmm yeah. Might be able to do something that walks an object tree and just clears all rvar caches for saving.
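
Something along these lines, perhaps (a rough sketch relying on the internal invalidate_rvar_cache(); not an existing posterior function):

# hypothetical: recursively clear rvar caches in an arbitrary object before saving
clear_rvar_caches = function(x) {
  if (is_rvar(x)) {
    posterior:::invalidate_rvar_cache(x)
  } else if (is.list(x)) {
    # data frames and tibbles are lists, so their columns get walked too
    x[] = lapply(x, clear_rvar_caches)
    x
  } else {
    x
  }
}

saveRDS(clear_rvar_caches(df), "df_tibble.rds")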