rstudio / tfdatasets

R interface to TensorFlow Datasets API
https://tensorflow.rstudio.com/tools/tfdatasets/

dataset_padded_batch using FixedLenSequenceFeature #3

rboubela closed this issue 6 years ago

rboubela commented 6 years ago

When using TFRecords for sequence data with a context feature of shape [1] and a sequence feature of shape [NULL, 6], it is pretty straightforward in Python to get a padded batch:

dataset = dataset.padded_batch(batch_size=4, padded_shapes=([1, ], [None, 6]), padding_values=None)

However, with the following example, which uses a FixedLenSequenceFeature, I can't work out how to specify padded_shapes correctly:

library(tfdatasets)

tf_filenames <- "some.tfrecord"

dataset <- tfrecord_dataset(tf_filenames) %>%
  dataset_map(function(example_proto) {
    context_features <- list(
      id = tf$FixedLenFeature(shape(list(1)), tf$string)
    )
    sequence_features = list(
      time_series = tf$FixedLenSequenceFeature(shape(list(6)), tf$float32)
    )
    tf$parse_single_sequence_example(example_proto, context_features, sequence_features)
  }) %>%
  dataset_padded_batch(batch_size = 8, padded_shapes = <???>)

Any idea or suggestion would be much appreciated! Thx!
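For readers unfamiliar with padded batching: `padded_batch` pads each variable-length dimension (the None in [None, 6]) up to the longest element in the batch, while fixed dimensions such as [1] must already agree. With `padding_values=None`, numeric components are padded with 0 and string components with the empty string. The sketch below is a simplified pure-Python model of that behavior for the sequence component only; `pad_sequences` is a hypothetical helper for illustration, not a TensorFlow function.

```python
# Simplified model of padded batching for ONE variable-length component.
# Illustrative only: this is not how TensorFlow implements padded_batch.

def pad_sequences(sequences, feature_dim, pad_value=0.0):
    """Pad a list of [T_i, feature_dim] sequences to [max_T, feature_dim]."""
    max_len = max(len(seq) for seq in sequences)
    batch = []
    for seq in sequences:
        # The fixed dimension (6 here) must already match; only T_i varies.
        if any(len(row) != feature_dim for row in seq):
            raise ValueError("every row must have feature_dim values")
        padding = [[pad_value] * feature_dim for _ in range(max_len - len(seq))]
        batch.append([list(row) for row in seq] + padding)
    return batch


# Two examples with 2 and 3 time steps of 6 features each.
seq_a = [[1.0] * 6, [2.0] * 6]
seq_b = [[3.0] * 6, [4.0] * 6, [5.0] * 6]
batch = pad_sequences([seq_a, seq_b], feature_dim=6)
print(len(batch), len(batch[0]), len(batch[0][0]))  # 2 3 6
```

The shorter sequence gains one row of zeros, so both elements end up with shape [3, 6] and can be stacked into a single [2, 3, 6] batch.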

jjallaire commented 6 years ago

You should be able to do this:

dataset_padded_batch(batch_size = 8, padded_shapes = list(list(1), list(NULL, 6)))

If that doesn't work, could you provide a full reproducible example that fails? Then I'll investigate further.

rboubela commented 6 years ago

Thanks for your fast reply! Using your suggestion with the example above (with TensorFlow 1.4.1; at this stage the error does not depend on the actual input file), I get the following error:

library(tfdatasets)

tf_filenames <- "some.tfrecord"

dataset <- tfrecord_dataset(tf_filenames) %>%
  dataset_map(function(example_proto) {
    context_features <- list(
      id = tf$FixedLenFeature(shape(list(1)), tf$string)
    )
    sequence_features = list(
      time_series = tf$FixedLenSequenceFeature(shape(list(6)), tf$float32)
    )
    tf$parse_single_sequence_example(example_proto, context_features, sequence_features)
  }) %>%
  dataset_padded_batch(batch_size = 8, padded_shapes = list(list(1), list(NULL, 6)))

Error in py_call_impl(callable, dots$args, dots$keywords) : TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <type 'list'>.
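The error message points to a structure mismatch. `tf$parse_single_sequence_example` returns a pair (context dict, sequence dict), so each dataset element is a tuple of dicts, while `padded_shapes` reaches Python as a plain list of lists. Judging from the message ("Input has type: <type 'list'>"), tf.data's structure check at the time did not treat a Python list as a nested sequence. The sketch below is a simplified pure-Python model of that rule, assuming tuples and dicts count as structures and lists count as leaves (shapes); `check_structure` and `fails` are hypothetical helpers, not TensorFlow's internal nest utilities.

```python
# Simplified model: padded_shapes must mirror the nesting of a dataset element.
# Assumption: tuples and dicts are structures; anything else (including a
# Python list such as [None, 6]) is treated as a leaf.

def check_structure(element, shapes):
    """Raise TypeError if `shapes` does not mirror the nesting of `element`."""
    if isinstance(element, dict):
        if not isinstance(shapes, dict) or set(shapes) != set(element):
            raise TypeError(f"expected dict with keys {sorted(element)}, "
                            f"got {type(shapes).__name__}")
        for key in element:
            check_structure(element[key], shapes[key])
    elif isinstance(element, tuple):
        if not isinstance(shapes, tuple) or len(shapes) != len(element):
            raise TypeError(f"expected tuple of length {len(element)}, "
                            f"got {type(shapes).__name__}")
        for sub_element, sub_shape in zip(element, shapes):
            check_structure(sub_element, sub_shape)
    # Otherwise `element` is a leaf (a tensor), and `shapes` is its shape.


def fails(element, shapes):
    try:
        check_structure(element, shapes)
    except TypeError:
        return True
    return False


# Strings stand in for tensors.
parsed = ({"id": "<string tensor>"}, {"time_series": "<float tensor>"})
flat = ("<string tensor>", "<float tensor>")

print(fails(parsed, [[1], [None, 6]]))  # True: top level is a list, not a tuple
print(fails(flat, [[1], [None, 6]]))    # True: still a list at the top level
print(fails(flat, ([1], [None, 6])))    # False: structures match
```

In this model, both returning a flat pair of tensors from the map function and wrapping the outer level of `padded_shapes` in a tuple are needed.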

jjallaire commented 6 years ago

Okay, so that I can troubleshoot and fix this most productively, could you provide code that I can run to reproduce the error (including the "some.tfrecord" file)? A Gist (https://gist.github.com/) is probably the best format for this.

rboubela commented 6 years ago

Thanks for your patience! https://gist.github.com/rboubela/853a104e70b8ae580e94159aab2de454

jjallaire commented 6 years ago

Thanks, could you also update the gist with a .py file that works correctly for the same example?

rboubela commented 6 years ago

Of course. I've updated the gist and added TFRecordDataset_padded_batch.py.

jjallaire commented 6 years ago

With a commit I just made to tfdatasets, the following code runs, although it yields a warning (I'm not sure whether that is just because of the nature of the data in some.tfrecord):

library(tfdatasets)
library(tensorflow)
library(zeallot)

tf_filenames <- "some.tfrecord"

dataset <- tfrecord_dataset(tf_filenames) %>%
  dataset_map(function(example_proto) {
    context_features <- list(
      id = tf$FixedLenFeature(shape(1), tf$string)
    )
    sequence_features = list(
      time_series = tf$FixedLenSequenceFeature(shape(6), tf$float32)
    )
    c(context, sequence) %<-% tf$parse_single_sequence_example(
        serialized = example_proto,
        context_features = context_features,
        sequence_features = sequence_features
    )

    list(context[["id"]], sequence[["time_series"]])

  }) %>%
  dataset_padded_batch(
    batch_size = 2,
    padded_shapes = tuple(list(1), list(NULL, 6)),
    padding_values = NULL
  )

batch <- next_batch(dataset)
batch

sess <- tf$Session()
sess$run(batch)

Mostly, I tried to make the R code mirror the Python code exactly.

Even if this works for you, I don't consider the issue closed, since users shouldn't need to call tuple() explicitly on the padded_shapes argument.

rboubela commented 6 years ago

Thank you very much; the new version of tfdatasets works perfectly! However, I could not reproduce the warning.

jjallaire commented 6 years ago

Okay, I just made a commit that lets you omit the call to tuple(), so this works:

padded_shapes = list(list(1), list(NULL, 6))
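One plausible way for a wrapper to let users omit tuple() is to convert any list whose elements are themselves structures into a tuple, while leaving lists of dimensions (integers or None) alone as shapes. The Python sketch below illustrates that idea only; `to_structure` is hypothetical and is not the code from the tfdatasets commit.

```python
# Hypothetical normalization: nested lists become tuples, while lists made
# only of dimensions (int or None), such as [None, 6], are kept as shapes.

def to_structure(value):
    if isinstance(value, dict):
        return {key: to_structure(sub) for key, sub in value.items()}
    if isinstance(value, (list, tuple)):
        if all(dim is None or isinstance(dim, int) for dim in value):
            return list(value)  # looks like a shape, e.g. [1] or [None, 6]
        return tuple(to_structure(sub) for sub in value)
    return value


print(to_structure([[1], [None, 6]]))  # ([1], [None, 6])
print(to_structure([{"id": [1]}, {"time_series": [None, 6]}]))
```

The design relies on a heuristic: a list made only of dimensions is read as a shape, and anything else as a container, which is enough to recover the tuple structure from an R list of lists.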