def _inner(window_size=None, window_shift=None):
    """Read the TFRecord files and return one batch of parsed examples.

    This is a closure. It reads ``tf``, ``tfrecord_fpaths``, ``shuffle``,
    ``num_parallel_reads``, ``feature_schema``, ``num_parallel_calls``,
    ``model_preprocessors``, ``mode``, ``shuffle_buffer_size``,
    ``num_epochs`` and ``batch_size`` from the enclosing function, which
    returns it.

    Args:
        window_size: Optional int. When set, consecutive parsed examples are
            grouped into windows of exactly ``window_size`` steps (for
            example, sequence input for an LSTM). Every feature then gains a
            leading step axis of length ``window_size``. A trailing
            incomplete window is dropped. ``None`` (the default) keeps one
            example per element, which is the original behavior.
        window_shift: Optional int giving the number of examples between
            the starts of consecutive windows. It defaults to
            ``window_size`` (non-overlapping windows); pass 1 for a fully
            sliding window. It is ignored when ``window_size`` is ``None``.

    Returns:
        The result of ``get_next()`` on a one-shot iterator over the
        batched dataset.
    """
    dataset = tf.data.Dataset.from_tensor_slices(tfrecord_fpaths)
    if shuffle:
        # Shuffle file order only; the records inside each file are still
        # read in order by the flat_map below.
        dataset = dataset.shuffle(len(tfrecord_fpaths))
    # NOTE(review): each TFRecordDataset here is given a single filename, so
    # num_parallel_reads appears unable to read several files concurrently;
    # consider Dataset.interleave if parallel file reading is the intent.
    dataset = dataset.flat_map(
        lambda filename: tf.data.TFRecordDataset(
            filename,
            compression_type="GZIP",
            num_parallel_reads=num_parallel_reads,
        ),
    )
    # Deserialize the tf.Example protos into a dict of feature tensors.
    dataset = dataset.map(
        lambda serialized: tf.parse_single_example(serialized, feature_schema),
        num_parallel_calls=num_parallel_calls,
    )
    if window_size is not None:
        shift = window_size if window_shift is None else window_shift
        # window() turns the example stream into a stream of windows. Each
        # parsed element is a dict, so each window is a dict of per-feature
        # sub-datasets. Batch every sub-dataset back into one tensor of
        # length window_size, then zip the dict together again.
        # NOTE(review): windows are cut from the concatenation of all files,
        # so a window can span two files. If each file holds one sequence,
        # window per file inside the flat_map above instead. The
        # preprocessors below also now receive a leading step axis; confirm
        # that they handle it.
        dataset = dataset.window(window_size, shift=shift, drop_remainder=True)
        dataset = dataset.flat_map(
            lambda window: tf.data.Dataset.zip(
                {key: steps.batch(window_size) for key, steps in window.items()}
            )
        )
    for preprocessor in model_preprocessors or []:
        dataset = preprocessor.preprocess(dataset, mode)
    if shuffle:
        # With windowing enabled, this reorders whole windows and never the
        # steps inside a window, so the sequence order within each window
        # is preserved.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    example_batch = iterator.get_next()
    return example_batch
return _inner
What should I do to get sequential (serial) data for an LSTM network? I think `dataset.window` is needed, but I don't know how to add it to this function. Can you help me? Thank you!
# NOTE(review): this line appears to be a whitespace-collapsed re-paste of the
# start of the _inner function earlier in this paste (same tokens through the
# flat_map call). It is not valid Python as written, and it stops before the
# parsing, preprocessing, batching and return steps.
def _inner(): """Read tfrecords, this func will be returned""" dataset = tf.data.Dataset.from_tensor_slices(tfrecord_fpaths) if shuffle: dataset = dataset.shuffle(len(tfrecord_fpaths)) dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename, compression_type="GZIP", num_parallel_reads=num_parallel_reads, ), )
# Deserialize the tf.Example protos
What should I do to get sequential (serial) data for an LSTM network? I think `dataset.window` is needed, but I don't know how to add it to this function. Can you help me? Thank you!