Please try adjusting the scope of the make_reader context manager. I think that in your case, the reader is actually stopped when you exit the context manager, which triggers epoch termination earlier than expected.
This code works for me:
import tensorflow as tf
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

def streaming_parser(serialized_example):
    # Flatten each 28x28 MNIST image into a 784-element float vector.
    image_data = tf.cast(tf.reshape(serialized_example.image, [784]), tf.float32)
    label = serialized_example.digit
    return {"image_data": image_data}, label

# Everything that consumes the reader, including the sess.run loop,
# stays inside the context manager so the reader is not stopped early.
with make_reader("file:///tmp/mnist/train",
                 num_epochs=1000,
                 workers_count=1) as reader:
    exp_dataset = (make_petastorm_dataset(reader)
                   .map(streaming_parser)
                   .batch(1000))
    features, labels = exp_dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        cum_count = 0
        for idx in range(610):
            labels_manifested = sess.run([labels])
            count = labels_manifested[0].shape[0]
            cum_count += count
            print(f"Batch {idx}, contains {count} records, total records read is {cum_count}")
Okay - that makes sense. It's working with the expanded context manager scope. Thanks.
Hi, Selitvin. I'm wondering if there is a safe way to use the dataset without a context manager. Since I need to feed the dataset into multiple different frameworks, I don't want all the code to be inside a context manager.
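Something along these lines is what I have in mind (a rough sketch, assuming the Reader returned by make_reader can be shut down explicitly with its stop() and join() methods, which I believe are the same calls the context manager makes on exit):

from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

# Create the reader without a `with` block so the dataset can be
# handed off to other frameworks; it must be shut down explicitly later.
reader = make_reader("file:///tmp/mnist/train", num_epochs=None)
dataset = make_petastorm_dataset(reader).batch(1000)

# ... pass `dataset` to whatever framework consumes it ...

# Explicit cleanup once all iteration is done (assumed to be
# equivalent to exiting the context manager).
reader.stop()
reader.join()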
In the example below, I'm iterating over the MNIST parquet file generated by the MNIST example. num_epochs is set to None to get the "infinite number of epochs" behavior. However, the iterator loop hangs after 600 batches. Given that the parquet file contains 60,000 records and I requested a batch size of 1,000, this corresponds to 10 epochs.
Code to reproduce:
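(The original snippet is not shown; below is a hedged reconstruction of the pattern described above, assuming the same MNIST data, num_epochs=None, and the sess.run loop placed outside the make_reader block.)

import tensorflow as tf
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

with make_reader("file:///tmp/mnist/train", num_epochs=None) as reader:
    exp_dataset = make_petastorm_dataset(reader).batch(1000)
    features_op = exp_dataset.make_one_shot_iterator().get_next()
# Exiting the context manager stops the reader here.

with tf.Session() as sess:
    idx = 0
    while True:
        # Reported to hang after 600 batches (10 epochs of the
        # 60,000-record file), since the reader was already stopped.
        sess.run(features_op)
        idx += 1
        print(f"Batch {idx}")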