rstudio / tfdatasets

R interface to TensorFlow Datasets API
https://tensorflow.rstudio.com/tools/tfdatasets/

tf dataset input_fn: support custom estimators #1

Closed: jjallaire closed this 7 years ago

jjallaire commented 7 years ago

@terrytangyuan and @kevinushey I am attempting to add support for creating dataset input functions for custom estimators, but I am having a hard time getting the correct behavior (I get a variety of errors depending on what I try).

Here is where I do the tensor transformations to yield the input_fn tensors:

https://github.com/rstudio/tfdatasets/blob/feature/custom-estimator-input-fn/R/input_fn.R#L51-L57

Note that the canned estimator transform works fine (in that case I return a tuple consisting of a dict of feature tensors and a response tensor). For custom estimators I can't quite figure out what to return.

Could one or both of you have a look to see what I might be doing wrong? The repro is in the test case here (it's probably easiest to copy it out of the test case for direct execution in the REPL): https://github.com/rstudio/tfdatasets/blob/feature/custom-estimator-input-fn/tests/testthat/test-input-fn.R#L78

jjallaire commented 7 years ago

Note that the contract of the dataset_map function I am using here is to take a "record" (which is a named list of tensors) and transform it into something that should be yielded from the dataset's iterator. It looks like tensorflow does some inspection of the Python return values and then synthesizes the equivalent pure tensorflow function. The return types seem to be pretty flexible (you can see that for canned estimators I return a tuple with a nested dict and that all works). What I am returning for custom estimators is a simpler data structure but it still seems to fail.
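A minimal sketch of the canned-estimator structure described above: a named record goes in, and a tuple of (dict of feature columns, response column) comes out. Plain Python lists stand in for tensors, and `canned_transform`, `feature_names`, and `response_name` are hypothetical names; this is not the tfdatasets implementation, only an illustration of the data layout.

```python
from typing import Dict, List, Tuple

# Plain lists stand in for 1-D tensors (one value per row in the batch).
Column = list


def canned_transform(record: Dict[str, Column],
                     feature_names: List[str],
                     response_name: str) -> Tuple[Dict[str, Column], Column]:
    """Hypothetical map function: named record -> (features dict, response)."""
    # Feature names survive as dict keys, which is what canned estimators use.
    features = {name: record[name] for name in feature_names}
    response = record[response_name]
    return features, response


record = {
    "Sepal.Length": [5.1, 4.9],
    "Sepal.Width": [3.5, 3.0],
    "Petal.Length": [1.4, 1.4],
    "Petal.Width": [0.2, 0.2],
    "Species": [0, 0],
}
features, response = canned_transform(
    record,
    ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"],
    "Species",
)
print(sorted(features))
print(response)  # [0, 0]
```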

kevinushey commented 7 years ago

Is this the error you were seeing? (Just want to make sure I'm on the same page here)

>   train(classifier, input_fn(dataset, features = -Species, response = Species))
  Error in py_call_impl(callable, dots$args, dots$keywords) : 
  RuntimeError: Evaluation error: ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.. 
jjallaire commented 7 years ago

Yes that's what I'm seeing.

(Email reply to Kevin Ushey's comment of Tue, Oct 10, 2017 at 1:11 PM.)

kevinushey commented 7 years ago

These are the inputs the custom model function is receiving:

Browse[2]> str(as.list(environment()))
List of 5
 $ features:Tensor("IteratorGetNext:0", shape=(4, ?), dtype=float32, device=/device:CPU:0)
 $ labels  :Tensor("IteratorGetNext:1", shape=(?,), dtype=string, device=/device:CPU:0)
 $ mode    : chr "train"
 $ params  : Named list()
 $ config  :RunConfig

In tfestimators, the analogous code appears to be here:

https://github.com/rstudio/tfestimators/blob/master/tests/testthat/test-tf-custom-models.R#L16

and that model function receives:

Browse[1]> str(as.list(environment()))
List of 5
 $ features:Tensor("random_shuffle_queue_DequeueUpTo:1", shape=(?, 4), dtype=float64, device=/device:CPU:0)
 $ labels  :Tensor("random_shuffle_queue_DequeueUpTo:2", shape=(?,), dtype=int32, device=/device:CPU:0)
 $ mode    : chr "train"
 $ params  : Named list()
 $ config  :RunConfig

Some possibilities:

jjallaire commented 7 years ago

I think we should get rid of the string concern by just transforming the iris.csv to have integers rather than strings (I don't see an easy pure TF way to map strings to integers when reading from a CSV).
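A sketch of that one-off preprocessing step, done outside TensorFlow with Python's standard `csv` module. The helper name `encode_species` is hypothetical, and assigning codes in sorted order of the distinct labels is an assumption (the thread does not say which mapping the updated iris.csv uses); any fixed mapping works as long as training and prediction agree.

```python
import csv
import io


def encode_species(csv_text: str, column: str = "Species") -> str:
    """Rewrite CSV text so that `column` holds integer codes instead of strings.

    Codes follow the sorted order of the distinct labels (an assumption).
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    fieldnames = reader.fieldnames
    rows = list(reader)
    labels = sorted({row[column] for row in rows})
    codes = {label: i for i, label in enumerate(labels)}

    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=fieldnames, lineterminator="\n")
    writer.writeheader()
    for row in rows:
        row[column] = str(codes[row[column]])  # string label -> integer code
        writer.writerow(row)
    return out.getvalue()


sample = (
    "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species\n"
    "5.1,3.5,1.4,0.2,setosa\n"
    "7.0,3.2,4.7,1.4,versicolor\n"
    "6.3,3.3,6.0,2.5,virginica\n"
)
print(encode_species(sample))
```

With integer labels in the file, the response column can be decoded directly as an integer field, which matches the `int32` labels the tfestimators custom model function receives.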

The shape should be easy enough to work around. There are some tf functions like tf.stack that will transform lists of tensors to other shapes. Maybe try that?

I'm sure that estimators can handle the iterator tensors, because they handle them fine for canned estimators.

jjallaire commented 7 years ago

I just pushed a change to the CSV file so that it includes integers.

I have an appointment so may not get back to this today. Let me know if you discover anything with respect to a transform that's more compatible with what is produced by input_fn.data.frame.

jjallaire commented 7 years ago

A clarifying question for my benefit: do the feature names appear anywhere in the input_fn output for custom estimators? I am currently just returning an unnamed list. If you could let me know the pseudocode for the object structure to be returned, I can try to yield that (it's hard to tell from the existing code in tfestimators because of the various levels of indirection).

jjallaire commented 7 years ago

@terrytangyuan If you could take a look when you get a chance, that would be appreciated as well (I'm still poking at it, but my lack of familiarity with estimators is a handicap).

terrytangyuan commented 7 years ago

Just saw the issue. I am looking at it now!

terrytangyuan commented 7 years ago

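I tried to stack and reshape features here (https://github.com/rstudio/tfdatasets/commit/c0b1f526172ff3918759550fb55a922ae27bb9dc) and the error you are seeing goes away. This is likely caused by numpy_input_fn(). I've seen a similar issue before (see the last comment of https://github.com/rstudio/tfestimators/issues/56).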

However, I am now seeing different errors that I haven't seen before (from tfdatasets):

2017-10-10 17:05:04.599195: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Field 0 in record 0 is not a valid float: Sepal.Length
     [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4)]]
 Error in py_call_impl(callable, dots$args, dots$keywords) : 
  InvalidArgumentError: Field 0 in record 0 is not a valid float: Sepal.Length
     [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[4,?], [?]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]] 
jjallaire commented 7 years ago

The issue may be that we are reading the column names as the first record; I'll take a look at this in a bit. (Email reply to Yuan (Terry) Tang's comment of Tue, Oct 10, 2017 at 5:10 PM.)

jjallaire commented 7 years ago

Okay, I figured out the issue with the invalid types (it was indeed the reading of the column names).

So my next question is whether we can make numpy_input_fn and input functions from datasets behave consistently (i.e. we shouldn't need special transformations when dealing with dataset-based input functions). @terrytangyuan Is there a way to take the workaround that you found and apply it generically when we construct the input_fn for datasets?

jjallaire commented 7 years ago

I believe I've figured it out! Here's the commit where I stack the record on the second axis:

https://github.com/rstudio/tfdatasets/pull/1/commits/4a2194d370ab6eaa8d10db0273206dc5a9ccea41

This runs with no errors. @terrytangyuan Could you just double-check that the model is evaluating the correct feature data? Assuming it checks out on your end feel free to merge this.
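Why stacking on the second axis fixes the `Dense` error can be seen from shapes alone. The sketch below mimics the shape rule of `tf.stack` for equally shaped inputs (a new dimension of size N is inserted at `axis`, with TensorFlow's 0-based axis numbering) and uses `None` for the unknown batch size. `stacked_shape` and `check_dense_input` are hypothetical helpers; the latter is only a stand-in for the check that produced the error message earlier in the thread, not the Keras code itself.

```python
from typing import List, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def stacked_shape(input_shapes: List[Shape], axis: int) -> Shape:
    """Shape obtained by stacking N equally shaped tensors along `axis`."""
    base = input_shapes[0]
    if any(s != base for s in input_shapes):
        raise ValueError("all inputs must have the same shape")
    dims = list(base)
    dims.insert(axis, len(input_shapes))  # new dimension of size N
    return tuple(dims)


def check_dense_input(shape: Shape) -> int:
    """Hypothetical stand-in: a dense layer needs a known last dimension."""
    if shape[-1] is None:
        raise ValueError("The last dimension of the inputs to `Dense` "
                         "should be defined. Found `None`.")
    return shape[-1]


# Four feature columns, each a batch of unknown size: shape (None,).
columns = [(None,)] * 4

print(stacked_shape(columns, axis=0))  # (4, None): features first, Dense fails
print(stacked_shape(columns, axis=1))  # (None, 4): batch first, Dense works
print(check_dense_input(stacked_shape(columns, axis=1)))  # 4
```

Stacking on axis 0 reproduces the `shape=(4, ?)` features tensor seen in the failing model function, while axis 1 gives the `(?, 4)` layout that the tfestimators custom model receives.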

terrytangyuan commented 7 years ago

Looks good to me! Glad you figured it out!!

terrytangyuan commented 7 years ago

Regarding skip = 1 though, should we change the default so users won't worry about it? Or throw a better error message than the one I saw?

jjallaire commented 7 years ago

Thanks! The skip = 1 is necessary because we are explicitly providing column names (even though there are already column names in the file). Only the user knows whether the file actually has column names, so I don't know if there is a good way to improve this error condition. I do document that you need to add skip = 1 in this case, though, so hopefully most people will never see this.
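A small sketch of why the header line trips the float decoder when column names are supplied explicitly: the file's own header becomes just another record unless it is skipped. `parse_records` is a hypothetical pure-Python stand-in that imitates the `DecodeCSV` error message reported above; it is not the tfdatasets or TensorFlow implementation.

```python
from typing import Dict, List


def parse_records(lines: List[str], col_names: List[str],
                  skip: int = 0) -> List[Dict[str, float]]:
    """Parse CSV lines as floats using explicitly supplied column names."""
    records = []
    for i, line in enumerate(lines[skip:]):
        record = {}
        for j, (name, text) in enumerate(zip(col_names, line.split(","))):
            try:
                record[name] = float(text)
            except ValueError:
                # Mirrors the wording of the DecodeCSV error seen in the thread.
                raise ValueError(
                    f"Field {j} in record {i} is not a valid float: {text}"
                ) from None
        records.append(record)
    return records


csv_lines = [
    "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species",
    "5.1,3.5,1.4,0.2,0",
]
# Names supplied by the caller, as when a record spec lists them explicitly.
col_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Species"]

try:
    parse_records(csv_lines, col_names)  # header line is parsed as data
except ValueError as err:
    print(err)  # Field 0 in record 0 is not a valid float: Sepal.Length

print(parse_records(csv_lines, col_names, skip=1))
```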

terrytangyuan commented 7 years ago

Got it! Yep that should be sufficient.