rstudio / tfdatasets

R interface to TensorFlow Datasets API
https://tensorflow.rstudio.com/tools/tfdatasets/

can't save/restore models that use R6 classes that inherit from `Step` #87

Open t-kalinowski opened 3 years ago

t-kalinowski commented 3 years ago

https://github.com/rstudio/keras/issues/1231

t-kalinowski commented 3 years ago
library(tensorflow)
library(tfdatasets)
library(keras)

#create data frame / training data:
a = rnorm(2000, 1.2, sd = 5)
b = rnorm(2000, 1, sd = 2)

cc = sample(LETTERS, size = 2000, replace = TRUE)
d = sample(c(TRUE,FALSE), size = 2000, replace = TRUE)

int_dummy <- seq(1,10, by = 1)
e = sample(int_dummy, size = 2000, replace = TRUE)
e <- as.integer(e)

y = sample(c(0,1), size = 2000, replace = TRUE)

ds <- tibble::lst(a, b, cc, d, e, y) %>%
  tensor_slices_dataset() %>%
  dataset_batch(32)

spec <- feature_spec(ds, y ~ .) %>%
  step_numeric_column(d) %>%
  step_numeric_column(c(a,b)) %>%
  # step_numeric_column(c(a,b), normalizer_fn = scaler_standard()) %>%
  # ## use of scaler_standard() causes model saving/restoring to fail here
  step_categorical_column_with_vocabulary_list(cc, vocabulary_list = LETTERS) %>%
  step_indicator_column(cc) %>%
  step_categorical_column_with_identity(e, num_buckets = 11, default_value = 10) %>%
  step_indicator_column(e)

spec_prep <- fit(spec)

ds <- dataset_use_spec(ds, spec_prep)

input <- layer_input_from_dataset(ds)

output <- input %>%
  layer_dense_features(dense_features(spec_prep)) %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")

model <- keras_model(input, output)

model %>% compile(
  loss = loss_binary_crossentropy(),
  optimizer = "rmsprop",
  metrics = "binary_accuracy"
)

history <- model %>% fit(ds, epochs = 1)

save_model_tf(object = model, filepath = "tf_models/1")
new_m <- load_model_tf("tf_models/1/") # Error if `normalizer_fn = scaler_standard()` above

Fixing this will require adding a `get_config()` method to `Step`, and likely to each of the R6 subclasses as well. This is tricky because the R6 object itself is never passed to the Python side; what gets passed is an R closure produced on demand with `Step$fun()`. That means we'll have to start passing Python objects that wrap the R closure, expose it through `__call__`, and also offer a `get_config()` method.
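
As a rough illustration of the wrapper idea, here is a minimal pure-Python sketch. The class name `RClosureWrapper`, the helper `make_standard_scaler()`, and the config keys are all hypothetical and not part of tfdatasets or Keras; an ordinary Python lambda stands in for the R closure that `Step$fun()` would produce.

```python
class RClosureWrapper:
    """Hypothetical wrapper: callable like the closure, but also describable."""

    def __init__(self, fn, config):
        self._fn = fn                # stand-in for the R closure from Step$fun()
        self._config = dict(config)  # plain, serializable description of the step

    def __call__(self, x):
        # Delegate to the wrapped closure so the object can be used
        # anywhere a plain normalizer function is accepted.
        return self._fn(x)

    def get_config(self):
        # Return a copy so callers cannot mutate the stored config.
        return dict(self._config)


def make_standard_scaler(mean, sd):
    # Plays the role of the closure produced on demand by a fitted step.
    return lambda x: (x - mean) / sd


wrapped = RClosureWrapper(
    make_standard_scaler(mean=1.0, sd=2.0),
    config={"class_name": "StandardScaler", "mean": 1.0, "sd": 2.0},
)
```

The point of the sketch is only that the serializable state (here `mean` and `sd`) has to travel alongside the closure, because the closure alone gives the Python side nothing to write out at save time.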

Once `get_config()` works, `from_config()` will also require changes. None of the `Step` classes can currently be instantiated as already-fitted objects, so we'd need to redefine the initializer of each R6 class so that it allows bypassing a `fit()` call.
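
The initializer change could look something like the following pure-Python sketch. `StandardScalerStep` is a hypothetical stand-in for an R6 `Step` subclass, not an existing class; the assumption is that the fitted state is small enough (a mean and a standard deviation) to pass straight to the constructor.

```python
import statistics


class StandardScalerStep:
    """Hypothetical step that can be created either unfitted or already fitted."""

    def __init__(self, mean=None, sd=None):
        # Accepting fitted values here is what lets from_config() skip fit().
        self.mean = mean
        self.sd = sd

    @property
    def fitted(self):
        return self.mean is not None and self.sd is not None

    def fit(self, values):
        # Normal path: learn the statistics from data.
        self.mean = statistics.mean(values)
        self.sd = statistics.stdev(values)
        return self

    def __call__(self, x):
        if not self.fitted:
            raise RuntimeError("step must be fitted or restored before use")
        return (x - self.mean) / self.sd

    def get_config(self):
        return {"mean": self.mean, "sd": self.sd}

    @classmethod
    def from_config(cls, config):
        # Restore path: rebuild a fitted object directly from saved values.
        return cls(**config)


original = StandardScalerStep().fit([1.0, 2.0, 3.0, 4.0, 5.0])
restored = StandardScalerStep.from_config(original.get_config())
```

Here `restored` behaves like `original` without ever seeing the training data, which is the property a reloaded model needs.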

This will also likely require exporting a new symbol that can be passed to `load_model_tf(..., custom_objects = )`.
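
To show why an exported symbol is needed, here is a toy pure-Python model of name-based deserialization. It is a simplified sketch of the general idea, not Keras's actual loader; `ScalerStep`, `serialize()`, and `deserialize()` are hypothetical names introduced for illustration.

```python
class ScalerStep:
    """Hypothetical stand-in for an exported, restorable step class."""

    def __init__(self, mean, sd):
        self.mean, self.sd = mean, sd

    def get_config(self):
        return {"mean": self.mean, "sd": self.sd}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def serialize(obj):
    # Only the class *name* and its config are stored, never the class itself.
    return {"class_name": type(obj).__name__, "config": obj.get_config()}


def deserialize(saved, custom_objects):
    # The loader must map the stored name back to a class it can call.
    name = saved["class_name"]
    if name not in custom_objects:
        raise ValueError(f"Unknown object: {name}; pass it via custom_objects")
    return custom_objects[name].from_config(saved["config"])


saved = serialize(ScalerStep(mean=1.0, sd=2.0))
restored = deserialize(saved, custom_objects={"ScalerStep": ScalerStep})
```

Without an entry for the stored name, the loader has no way to rebuild the object, so whatever class ends up wrapping the R closure would have to be reachable by users at load time.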