rstudio / tfdatasets

R interface to TensorFlow Datasets API
https://tensorflow.rstudio.com/tools/tfdatasets/

Compat with TF 2.0 #55

Closed atroiano closed 5 years ago

atroiano commented 5 years ago

[Update] This is definitely an issue.

This is more of a question than an issue. I am running into problems with an existing feature spec after upgrading the TensorFlow R package to support TF 2.0 and upgrading from TF 2.0 RC1 to TF 2.0. Is this package expected to work as-is with TF 2.0, or are updates coming?

Basically embedding layers don't appear to be working.

The error below is what I get:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  InvalidArgumentError: Value for attr 'T' of float is not in the list of allowed values: int32, int64
    ; NodeDef: {{node TruncatedNormal}}; Op<name=TruncatedNormal; signature=shape:T -> output:dtype; attr=seed:int,default=0; attr=seed2:int,default=0; attr=dtype:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE]; attr=T:type,allowed=[DT_INT32, DT_INT64]; is_stateful=true> [Op:TruncatedNormal] 

Below is what the embedding layer produces in the ft_spec. I had to cut some of it off for privacy reasons.

 dimension=8.0, combiner='mean', initializer=<tensorflow.python.ops.init_ops.TruncatedNormal>, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=0.001, trainable=True)
atroiano commented 5 years ago

The error occurs at step

layer_dense_features(ft_spec$dense_features())
dfalbel commented 5 years ago

It should work. I'll take a look ASAP

atroiano commented 5 years ago

Thanks, I'll monitor the thread all day if you need me to test or provide more info.

dfalbel commented 5 years ago

@atroiano I tested with the code below and everything works fine with 2.0. Does the code below also fail for you?

library(keras)
library(dplyr)
library(tfdatasets)

spec <- feature_spec(hearts, target ~ .) %>% 
  step_categorical_column_with_vocabulary_list(thal) %>% 
  step_embedding_column(thal, dimension = 2)
spec <- fit(spec)

input <- layer_input_from_dataset(hearts %>% select(-target))
output <- input %>% 
  layer_dense_features(feature_columns = dense_features(spec)) %>% 
  layer_dense(units = 1, activation = "sigmoid")

model <- keras_model(input, output)

model %>% 
  compile(
    loss = "binary_crossentropy", 
    optimizer = "adam",
    metrics = "accuracy"
  )

model %>% 
  fit(
    x = hearts %>% select(-target), y = hearts$target,
    validation_split = 0.2
  )
atroiano commented 5 years ago

I get the following error

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  InvalidArgumentError: Value for attr 'T' of float is not in the list of allowed values: int32, int64
    ; NodeDef: {{node TruncatedNormal}}; Op<name=TruncatedNormal; signature=shape:T -> output:dtype; attr=seed:int,default=0; attr=seed2:int,default=0; attr=dtype:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE]; attr=T:type,allowed=[DT_INT32, DT_INT64]; is_stateful=true> [Op:TruncatedNormal]
atroiano commented 5 years ago

I did upgrade to TF 2 from the TF RC by upgrading with pip install, so I wonder if I should use the TensorFlow R package to create a new env instead.

I just tested a clean install using the TF package and the same error occurred.

dfalbel commented 5 years ago

Yeah, just reinstalled TensorFlow and could reproduce, sorry! Will take a look now.

atroiano commented 5 years ago

@dfalbel I just reinstalled tfdatasets from GitHub and it updated my reticulate and I was able to run the example you provided without any issues.

dfalbel commented 5 years ago

Found that the issue is in

step_embedding_column(thal, dimension = 2)

The dimension is not converted to an integer before being passed to TensorFlow. So it works by default (since the default is an integer) but fails if you pass a plain numeric (double) value such as 2. Passing dimension = 2L fixes the issue.

I'll push a fix for this.
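
The integer/double distinction described above can be checked in plain R, without TensorFlow installed. This is only a minimal sketch: `as_dimension()` is a hypothetical helper written for illustration, not a function from tfdatasets, and it is not the actual fix that was committed. It just shows why `dimension = 2` reaches Python as a float (compare the `dimension=8.0` in the printed spec above) while `dimension = 2L` does not.

```r
# In R, a bare number literal is a double; the L suffix makes it an integer.
dim_double  <- 2
dim_integer <- 2L

is.integer(dim_double)   # FALSE
is.integer(dim_integer)  # TRUE

# Hypothetical helper (NOT part of tfdatasets): accept a single whole-number
# value and coerce it to integer, rejecting anything else.
as_dimension <- function(x) {
  stopifnot(
    is.numeric(x),
    length(x) == 1,
    !is.na(x),
    x == round(x),
    x >= 1
  )
  as.integer(x)
}

fixed <- as_dimension(dim_double)
is.integer(fixed)        # TRUE
identical(fixed, 2L)     # TRUE
```

On a tfdatasets version without the fix, writing the `L` suffix explicitly, as in `step_embedding_column(thal, dimension = 2L)`, is the workaround mentioned above.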

dfalbel commented 5 years ago

Seems like I had already fixed this 24 days ago :P Didn't remember at all. https://github.com/rstudio/tfdatasets/commit/4b11ee3767cfa3ff56265bd49df2171fbfe65a6e#diff-50ca450fec5925cc3e0e1935e33a999bR783

atroiano commented 5 years ago

Ha! I've been there before.