t-kalinowski opened this issue 3 years ago
library(tensorflow)
library(tfdatasets)
library(keras)
# Simulate training data (numeric, character, logical, and integer features plus a binary label):
a = rnorm(2000, 1.2, sd = 5)
b = rnorm(2000, 1, sd = 2)
cc = sample(LETTERS, size = 2000, replace = TRUE)
d = sample(c(TRUE,FALSE), size = 2000, replace = TRUE)
int_dummy <- seq(1,10, by = 1)
e = sample(int_dummy, size = 2000, replace = TRUE)
e <- as.integer(e)
y = sample(c(0,1), size = 2000, replace = TRUE)
ds <- tibble::lst(a, b, cc, d, e, y) %>%
  tensor_slices_dataset() %>%
  dataset_batch(32)
spec <- feature_spec(ds, y ~ .) %>%
  step_numeric_column(d) %>%
  step_numeric_column(c(a,b)) %>%
  # step_numeric_column(c(a,b), normalizer_fn = scaler_standard()) %>%
  # ## using scaler_standard() here causes model saving/restoring to fail
  step_categorical_column_with_vocabulary_list(cc, vocabulary_list = LETTERS) %>%
  step_indicator_column(cc) %>%
  step_categorical_column_with_identity(e, num_buckets = 11, default_value = 10) %>%
  step_indicator_column(e)
spec_prep <- fit(spec)
ds <- dataset_use_spec(ds, spec_prep)
input <- layer_input_from_dataset(ds)
output <- input %>%
  layer_dense_features(dense_features(spec_prep)) %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
model <- keras_model(input, output)
model %>% compile(
  loss = loss_binary_crossentropy(),
  optimizer = "rmsprop",
  metrics = "binary_accuracy"
)
history <- model %>% fit(ds, epochs = 1)
save_model_tf(object = model, filepath = "tf_models/1")
new_m <- load_model_tf("tf_models/1/") # Errors if `normalizer_fn = scaler_standard()` is used above
Fixing this will require adding a `get_config()` method to `Step`, and likely to each of the R6 subclasses as well.
This is tricky, because the R6 object itself is never passed to the Python side; instead, an R closure produced on demand by `Step$fun()` is. That means we'll have to start passing Python objects that wrap the R closure, exposing it through `__call__` while also offering a `get_config()` method.
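A minimal pure-R sketch of the first half of that idea, assuming a hypothetical `ScalerStepSketch` R6 class (not part of tfdatasets) that stands in for a fitted standard-scaler step: `fun()` returns the kind of closure that is handed to Python today, while `get_config()` returns the fitted parameters as plain data that could be serialized alongside the model. The Python-side wrapper exposing the closure via `__call__` is not shown.

```r
library(R6)

# Hypothetical stand-in for a fitted standard-scaler step.
# Class, field, and method names are illustrative, not the tfdatasets API.
ScalerStepSketch <- R6Class("ScalerStepSketch",
  public = list(
    mean = NULL,
    sd = NULL,
    # Learn the normalization parameters from data (the "fit" step).
    fit = function(x) {
      self$mean <- mean(x)
      self$sd <- sd(x)
      invisible(self)
    },
    # Produce the R closure; this is what currently gets passed to Python.
    fun = function() {
      m <- self$mean
      s <- self$sd
      function(x) (x - m) / s
    },
    # Return everything needed to rebuild the fitted step, as plain data.
    get_config = function() {
      list(mean = self$mean, sd = self$sd)
    }
  )
)

step <- ScalerStepSketch$new()
step$fit(c(1, 2, 3, 4, 5))
cfg <- step$get_config()
str(cfg)
```

One possible route for the Python-side wrapper is reticulate's `PyClass()`, which defines a Python class whose methods are R functions; whether that integrates cleanly with Keras model saving is an open question.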
Once `get_config()` works, `from_config()` will also require changes. None of the `Step` classes can currently be instantiated as already-fitted objects, so we'd need to redefine the initializer of each R6 class to allow bypassing the `fit()` call.
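A rough sketch of an initializer that allows bypassing `fit()`, again assuming a hypothetical `ScalerStepSketch` class rather than any real tfdatasets step: if fitted parameters are supplied to `$new()`, the object is marked as fitted immediately, and a hypothetical `scaler_step_from_config()` helper plays the role of `from_config()`.

```r
library(R6)

# Hypothetical fitted-step class whose initializer accepts already-fitted
# parameters, so it can be rebuilt from a config without calling fit().
ScalerStepSketch <- R6Class("ScalerStepSketch",
  public = list(
    mean = NULL,
    sd = NULL,
    fitted = FALSE,
    initialize = function(mean = NULL, sd = NULL) {
      # If parameters are supplied, treat the step as already fitted.
      if (!is.null(mean) && !is.null(sd)) {
        self$mean <- mean
        self$sd <- sd
        self$fitted <- TRUE
      }
    },
    # Normal path: learn parameters from data.
    fit = function(x) {
      self$mean <- mean(x)
      self$sd <- sd(x)
      self$fitted <- TRUE
      invisible(self)
    },
    get_config = function() list(mean = self$mean, sd = self$sd)
  )
)

# Hypothetical from_config() counterpart: rebuild a fitted step from plain data.
scaler_step_from_config <- function(config) {
  do.call(ScalerStepSketch$new, config)
}

original <- ScalerStepSketch$new()$fit(c(2, 4, 6))
restored <- scaler_step_from_config(original$get_config())
```

The design choice here is to keep `fit()` as the normal path while making fitted state injectable, so a restored step never needs access to the original training data.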
This will also likely require exporting a new symbol that can be passed to `load_model_tf(..., custom_objects = )`.
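A simplified, pure-R analogue of what such an export might enable, assuming a hypothetical `step_custom_objects` list (no such symbol exists in tfdatasets): a named list maps the class name recorded in a saved config to a function that rebuilds the object. Real Keras deserialization instead resolves the name to a class and calls its `from_config()`, so this only illustrates the name-based lookup.

```r
library(R6)

# Hypothetical fitted step that can be rebuilt from its config.
ScalerStepSketch <- R6Class("ScalerStepSketch",
  public = list(
    mean = NULL,
    sd = NULL,
    initialize = function(mean, sd) {
      self$mean <- mean
      self$sd <- sd
    }
  )
)

# Hypothetical exported symbol: class name -> function rebuilding the object.
step_custom_objects <- list(
  ScalerStepSketch = function(config) do.call(ScalerStepSketch$new, config)
)

# Pure-R stand-in for a loader's name-based lookup on restore.
restore_step <- function(class_name, config, custom_objects) {
  builder <- custom_objects[[class_name]]
  if (is.null(builder)) stop("Unknown object: ", class_name)
  builder(config)
}

restored <- restore_step("ScalerStepSketch", list(mean = 1.2, sd = 5),
                         step_custom_objects)

# With a real export, usage would presumably look something like:
# new_m <- load_model_tf("tf_models/1/", custom_objects = step_custom_objects)
```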
https://github.com/rstudio/keras/issues/1231