merantix / imitation-learning

Autonomous driving: TensorFlow implementation of the paper "End-to-end Driving via Conditional Imitation Learning"
https://medium.com/merantix/journey-from-academic-paper-to-industry-usage-cf57fe598f31
MIT License

About input_fn function #19

Closed: zxbnjust closed this issue 4 years ago

zxbnjust commented 4 years ago

```python
def _inner():
    """Read tfrecords, this func will be returned"""
    dataset = tf.data.Dataset.from_tensor_slices(tfrecord_fpaths)
    if shuffle:
        dataset = dataset.shuffle(len(tfrecord_fpaths))
    dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename,
                                                                        compression_type="GZIP",
                                                                        num_parallel_reads=num_parallel_reads,
                                                                        ),
                               )

    # deserialize tfexamples
    dataset = dataset.map(lambda serialized: tf.parse_single_example(serialized, feature_schema),
                          num_parallel_calls=num_parallel_calls,
                          )

    for preprocessor in model_preprocessors or []:
        dataset = preprocessor.preprocess(dataset, mode)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    example_batch = iterator.get_next()
    return example_batch

return _inner
```

What should I do to get sequential data for an LSTM network? I think `dataset.window` is needed, but I don't know how to add it to this function. Can you help me? Thank you!
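One possible approach, sketched below under stated assumptions: apply `Dataset.window` after parsing but *before* the example-level shuffle, so each window holds consecutive frames; then shuffle and batch whole sequences. The helper names `make_sequences`, `sequence_input_fn`, `file_to_sequences` and the parameters `seq_len`/`shift` are hypothetical, not part of the repository. The sketch uses the TF 2.x spelling `tf.io.parse_single_example` and returns a `Dataset` instead of calling `make_one_shot_iterator`; the `model_preprocessors` loop and the parallelism arguments are left out for brevity. It also windows each TFRecord file on its own, on the assumption that one file is one continuous recording, so no sequence straddles two files.

```python
import tensorflow as tf


def make_sequences(dataset, seq_len, shift=1):
    """Hypothetical helper: group consecutive elements into fixed-length sequences.

    Assumes each element is a dict of tensors (as produced by
    tf.io.parse_single_example). Every output element is a dict whose
    tensors gain a leading time axis of length `seq_len`.
    """
    # For dict elements, window() yields one dict per window, mapping each
    # feature name to a sub-dataset of `seq_len` consecutive values.
    windows = dataset.window(seq_len, shift=shift, drop_remainder=True)

    def to_tensors(window):
        # Collapse each feature's sub-dataset into one [seq_len, ...] tensor
        # and zip the features back into a single dict element.
        return tf.data.Dataset.zip(
            {key: ds.batch(seq_len, drop_remainder=True) for key, ds in window.items()})

    return windows.flat_map(to_tensors)


def sequence_input_fn(tfrecord_fpaths, feature_schema, seq_len, batch_size,
                      shuffle=True, shuffle_buffer_size=1000, num_epochs=None):
    """Hypothetical TF 2.x variant of `_inner` yielding [batch, seq_len, ...] tensors."""

    def file_to_sequences(filename):
        # Parse and window each GZIP file separately so that no sequence
        # spans the end of one file and the start of the next.
        records = tf.data.TFRecordDataset(filename, compression_type="GZIP")
        examples = records.map(
            lambda serialized: tf.io.parse_single_example(serialized, feature_schema))
        return make_sequences(examples, seq_len)

    dataset = tf.data.Dataset.from_tensor_slices(tfrecord_fpaths)
    if shuffle:
        dataset = dataset.shuffle(len(tfrecord_fpaths))  # shuffle file order only
    dataset = dataset.flat_map(file_to_sequences)
    if shuffle:
        # Shuffles whole sequences; frame order inside each sequence is kept.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.repeat(num_epochs)
    return dataset.batch(batch_size)
```

With `shift=1` consecutive sequences overlap by `seq_len - 1` frames; passing `shift=seq_len` gives non-overlapping chunks, and `drop_remainder=True` discards a short tail at the end of each file. Each batched feature then has shape `[batch_size, seq_len, ...]`, the batch-major layout a Keras `LSTM` layer expects by default. On TF 1.x the same structure should carry over with `tf.parse_single_example` and the existing `make_one_shot_iterator` call, assuming your TensorFlow version provides `Dataset.window`.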