uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models directly from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark, and can also be used from pure Python code.
Apache License 2.0

iterator from make_reader hangs after 10 epochs even if num_epochs=None #396

Closed sdegryze closed 5 years ago

sdegryze commented 5 years ago

In the example below, I'm iterating over the MNIST Parquet file generated by the MNIST example.

num_epochs is set to None to get the "infinite number of epochs" behavior.

However, the iterator loop hangs after 600 batches. Given that the Parquet file contains 60,000 records and I requested a batch size of 1,000, that is 60 batches per epoch, so 600 batches corresponds to exactly 10 epochs.

Code to reproduce:

import tensorflow as tf
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

def streaming_parser(serialized_example):
    image_data = tf.cast(tf.reshape(serialized_example.image, [784]), tf.float32)
    label = serialized_example.digit
    return {"image_data": image_data}, label

with make_reader("s3://path/to/train",
                 num_epochs=None,
                 cur_shard=0,
                 shard_count=1) as reader:
    exp_dataset = (make_petastorm_dataset(reader)
                   .map(streaming_parser)
                   .batch(1000))

features, labels = exp_dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    cum_count = 0
    for idx in range(610):
        labels_manifested = sess.run([labels])
        count = labels_manifested[0].shape[0]
        cum_count += count
        print(f"Batch {idx}, contains {count} records, total records read is {cum_count}")
selitvin commented 5 years ago

Please try adjusting the scope of the make_reader context manager. I think that in your case the reader is actually stopped because you exit the context manager, which triggers earlier-than-expected epoch termination.

This code works for me:

import tensorflow as tf
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

def streaming_parser(serialized_example):
    image_data = tf.cast(tf.reshape(serialized_example.image, [784]), tf.float32)
    label = serialized_example.digit
    return {"image_data": image_data}, label

with make_reader("file:///tmp/mnist/train",
                 num_epochs=1000,
                 workers_count=1) as reader:
    exp_dataset = (make_petastorm_dataset(reader)
                   .map(streaming_parser)
                   .batch(1000))

    features, labels = exp_dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        cum_count = 0
        for idx in range(610):
            labels_manifested = sess.run([labels])
            count = labels_manifested[0].shape[0]
            cum_count += count
            print(f"Batch {idx}, contains {count} records, total records read is {cum_count}")
sdegryze commented 5 years ago

Okay - that makes sense. It's working with the expanded context manager scope. Thanks.

HaneulKim214 commented 1 year ago

Hi, Selitvin. I'm wondering if there is a safe way to use the dataset without a context manager. Since I need to feed the dataset into multiple different frameworks, I don't want all of the code to live inside a context manager.
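
For illustration, a minimal sketch of what I have in mind, assuming the reader returned by make_reader exposes the same stop()/join() cleanup that the context manager performs on exit (please correct me if that assumption is wrong):

from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

# Hypothetical sketch only: manage the reader's lifetime explicitly instead of
# relying on a `with` block, so the dataset can be handed to other code.
reader = make_reader("file:///tmp/mnist/train",
                     num_epochs=1000,
                     workers_count=1)
try:
    dataset = make_petastorm_dataset(reader).batch(1000)
    # ... pass `dataset` to whichever framework consumes it ...
finally:
    # Assumed equivalent of exiting the `with make_reader(...)` block.
    reader.stop()   # ask the worker threads/processes to stop
    reader.join()   # block until they have terminated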