rstudio / tfdatasets

R interface to TensorFlow Datasets API
https://tensorflow.rstudio.com/tools/tfdatasets/

use feature_spec with TF Probability #46

Closed atroiano closed 5 years ago

atroiano commented 5 years ago

I am creating a model that is outputting a probability distribution based on the sample in this article https://blogs.rstudio.com/tensorflow/posts/2019-06-05-uncertainty-estimates-tfprobability/

I set up a basic test using tfdatasets and can get the model to train, but I don't know how to make it apply the feature_spec so that it receives the right input for scoring.

library(keras)
library(tfdatasets)
library(tidyverse)
library(dplyr)
library(data.table)
library(tfprobability)
library(tensorflow)
library(ggplot2)

negloglik <- function(y, model){- (model %>% tfd_log_prob(y))}

prior_trainable <- function(kernel_size,
                            bias_size = 0,
                            dtype = NULL) {
  n <- kernel_size + bias_size
  keras_model_sequential() %>%
    layer_variable(n, dtype = dtype, trainable = TRUE) %>%
    layer_distribution_lambda(function(t) {
      tfd_independent(tfd_normal(loc = t, scale = 1),
                      reinterpreted_batch_ndims = 1)
    })
}
posterior_mean_field <-
  function(kernel_size, bias_size = 0, dtype = NULL) {
    n <- kernel_size + bias_size
    c <- log(expm1(1))
    keras_model_sequential(list(
      layer_variable(shape = 2 * n, dtype = dtype),
      layer_distribution_lambda(
        make_distribution_fn = function(t) {
          # independence is an assumption
          tfd_independent(tfd_normal(
            loc = t[1 : n],
            scale = 1e-5 + tf$nn$softplus(c + t[(n + 1) : (2 * n)])
          ), reinterpreted_batch_ndims = 1)
        }
      )
    ))
  }

data_train = data.frame(a=rep(sample(1000),2),b=rep(sample(1000),2))
data_test = data.frame(a=rep(sample(100),2),b=rep(sample(100),2))
n <- nrow(data_train)

ft_spec_1<- data_train %>% 
  feature_spec(b ~.) %>%
  step_numeric_column(one_of(c('a')),normalizer_fn = scaler_standard()) %>%
  fit()

inputs <- layer_input_from_dataset(data_train %>% select(a))

dense_layers <- 
  inputs %>% 
  layer_dense_features(ft_spec_1$dense_features()) %>%  
  layer_dense(units = 5, activation = "relu") %>%
  layer_dense_variational(
    units = 2,
    make_posterior_fn = posterior_mean_field,
    make_prior_fn = prior_trainable,
    kl_weight = 1 / n,
    activation = "linear"
  ) %>%
  layer_distribution_lambda(function(x)
    tfd_normal(loc = x[, 1, drop = FALSE],
               scale = .001 + tf$math$softplus(0.01 * x[, 2, drop = FALSE])), name = str_c('output_', '1'))

model <- keras_model(inputs, dense_layers)
model %>%
  compile(
    loss = negloglik,
    optimizer = optimizer_adam(),
    metrics = 'mae'
  )
model %>%
  fit(
    x = data_train,
    y = data_train$b,
    epochs = 10,
    validation_data = list(data_test, data_test$b),
    batch_size = 512
  )

score_test_data <- model(data_test)
#Error: Invalid input to layer function (must be a model or a tensor)

I can manually define the shape of the tensor I need to score, as below, and it outputs the prediction I need. However, this defeats one of the purposes of tfdatasets, because I would need to compute the mean and sd of the dataset and apply the scaling manually. The problem is compounded when I have embedding layers.

score_test_data <- model(tf$constant(as.matrix(data_test %>% select(a)),dtype=tf$float32))
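
For reference, the manual scaling described above could be sketched in base R as follows. This is only an illustration: `standardize_with` is a hypothetical helper (not part of tfdatasets), and using R's `sd()` (the sample standard deviation) is an assumption, since `scaler_standard()` may use a different estimator. The point is that the statistics come from the training data and are reused for the test data.

```r
# Hypothetical helper (not part of tfdatasets): standardize `x` using
# statistics taken from a reference vector, typically the training column.
standardize_with <- function(x, reference) {
  (x - mean(reference)) / sd(reference)
}

set.seed(1)
train_a <- rep(sample(1000), 2)
test_a  <- rep(sample(100), 2)

# Scale the test column with the *training* mean and sd, not its own.
test_a_scaled <- standardize_with(test_a, train_a)

# Sanity check: the training column scaled by itself has mean 0 and sd 1.
train_a_scaled <- standardize_with(train_a, train_a)
```

Doing this by hand for every numeric column, and keeping the statistics in sync with the model, is exactly the bookkeeping a fitted feature_spec is meant to avoid.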

Is there a way to take a dataset and pass it to model() while applying the known feature_spec?

atroiano commented 5 years ago

It looks like I can do something like this:

scoring<-tensors_dataset(data_test)
scoring<-dataset_use_spec(scoring,ft_spec_1)
iter <- make_iterator_one_shot(scoring)
score_this<- iterator_get_next(iter)
scored_test_data <- model(score_this[[1]])

but the preprocessing is not applied to the dataset.

Unrelated to my example above

library(tfdatasets)
data(hearts)
file <- tempfile()
writeLines(unique(hearts$thal), file)
hearts <- tensor_slices_dataset(hearts) %>% dataset_batch(32)

# use the formula interface
spec <- feature_spec(hearts, target ~ thal) %>%
  step_categorical_column_with_vocabulary_list(thal) %>%
  step_embedding_column(thal, dimension = 3)
spec_fit <- fit(spec)
final_dataset <- hearts %>% dataset_use_spec(spec_fit)
iter <- make_iterator_one_shot(final_dataset)
score_this<- iterator_get_next(iter)
score_this

This is another example where dataset_use_spec does not appear to be applying any transformation to the data.

dfalbel commented 5 years ago

Can you try changing your last line to:

score_test_data <- model(as.list(data_test %>% select(a)))

atroiano commented 5 years ago

It says the feature can't have rank 0:

score_test_data <- model(as.list(data_test %>% select(a)))

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Feature (key: a) cannot have rank 0. Given: 36 

dfalbel commented 5 years ago

@atroiano As a workaround you can do this: model(keras_array(data_test)). I'll figure out how to make the keras_array call the default in Keras.

atroiano commented 5 years ago

@dfalbel That works well for getting data scored without scaling it (I don't see where the scaling would happen in this instance).

It appears the following code will work for scoring as well.

scoring<-tensors_dataset(data_test)
scoring<-dataset_use_spec(scoring,ft_spec_1)
iter <- make_iterator_one_shot(scoring)
score_this<- iterator_get_next(iter)
scored_test_data <- model(score_this[[1]])

When this runs, I assume dataset_use_spec will apply the feature_spec transformation to my dataset, based on the dataset it was fit on, but that does not appear to happen:

step_numeric_column(one_of(c('a')),normalizer_fn = scaler_standard())

score_this[[1]] is the right object to pass, but the data in it is not scaled as specified in the step above.

I have a more complex example that uses embedding layers, and I am running into the same issue, except that there the columns are not being converted based on the vocabulary.

dfalbel commented 5 years ago

This should actually scale the inputs, since layer_dense_features adds the transformations to the graph. E.g.:

library(tfdatasets)
library(keras)

df <- data.frame(
  x = 1:10,
  y = 1:10
)

spec <- feature_spec(df, y ~ x) %>% 
  step_numeric_column(x, normalizer_fn = scaler_standard())

spec <- fit(spec)

inputs <- layer_input_from_dataset(df)
output <- layer_dense_features(inputs, feature_columns = spec$dense_features())

model <- keras_model(inputs, output)

model(keras_array(df))

atroiano commented 5 years ago

How do you get this example to work with embedding columns?

library(tfdatasets)
library(keras)

df <- data.frame(
  x = 1:10,
  z = c(rep('b',5),rep('a',5)),
  y = 1:10
)

k_clear_session()
spec <- feature_spec(df, y ~ x+z) %>% 
  step_numeric_column(x, normalizer_fn = scaler_standard()) %>% 
  step_categorical_column_with_vocabulary_list(z) %>% 
  step_embedding_column(z) %>% fit()

inputs <- layer_input_from_dataset(df)
outputs <- 
  inputs %>% 
  layer_dense_features(spec$dense_features())  %>% 
  layer_dense(units=1)

model <- keras_model(inputs, outputs)

model(keras_array(df))

This does not work; it says:

Error in py_get_attr_impl(x, name, silent) : 
  AttributeError: 'list' object has no attribute 'dtype'

dfalbel commented 5 years ago

I think the problem here is just related to how we deal with factors. Works as expected if you set:

df <- data.frame(
  x = 1:10,
  z = c(rep('b',5),rep('a',5)),
  y = 1:10,
  stringsAsFactors = FALSE
)
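
To illustrate why the column type matters on the R side, here is a small base R sketch (no TensorFlow required). With `stringsAsFactors = TRUE`, which was the default before R 4.0.0, a character column is stored as a factor, that is, integer codes plus a table of levels. How tfdatasets converts factors internally is not shown in this thread, so the sketch only demonstrates the difference in R and one possible way to convert factor columns back to character.

```r
df_factor <- data.frame(
  z = c(rep("b", 5), rep("a", 5)),
  stringsAsFactors = TRUE
)
df_char <- data.frame(
  z = c(rep("b", 5), rep("a", 5)),
  stringsAsFactors = FALSE
)

is.factor(df_factor$z)    # the column is a factor
is.character(df_char$z)   # the column is a plain character vector

# A factor is backed by integer codes that index into its levels.
codes <- as.integer(df_factor$z)

# Convert every factor column of an existing data frame back to character.
is_fct <- vapply(df_factor, is.factor, logical(1))
df_fixed <- df_factor
df_fixed[is_fct] <- lapply(df_fixed[is_fct], as.character)
```

The conversion at the end is useful when the data frame already exists and cannot simply be rebuilt with `stringsAsFactors = FALSE`.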

atroiano commented 5 years ago

I get the error: Error in py_call_impl(callable, dots$args, dots$keywords) : ValueError: Column dtype and SparseTensors dtype must be compatible. key: z, column dtype: <dtype: 'string'>, tensor dtype: <dtype: 'int32'>

library(tfdatasets)
library(keras)

df <- data.frame(
  x = 1:10,
  z = c(rep('b',5),rep('a',5)),
  y = 1:10,
  stringsAsFactors = FALSE
)

k_clear_session()
spec <- feature_spec(df, y ~ x+z) %>% 
  step_numeric_column(x, normalizer_fn = scaler_standard()) %>% 
  step_categorical_column_with_vocabulary_list(z) %>% 
  step_embedding_column(z) %>% fit()

inputs <- layer_input_from_dataset(df)
outputs <- 
  inputs %>% 
  layer_dense_features(spec$dense_features())  %>% 
  layer_dense(units=1)

model <- keras_model(inputs, outputs)

model(keras_array(df))

dfalbel commented 5 years ago

The above code just works for me. What's your TF version?

atroiano commented 5 years ago

2.0.0-beta1

dfalbel commented 5 years ago

Ok, this seems like a bug! A workaround is to call:

inputs <- reticulate::dict(layer_input_from_dataset(df))

I will push a fix to master ASAP.

atroiano commented 5 years ago

The workaround outlined above fixed the example as well as a more complicated model I have locally.

I really appreciate your help!

dfalbel commented 5 years ago

This should be fixed in master.

atroiano commented 5 years ago

Cheers. Thanks again for the quick support!