uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark and can be used from pure Python code.

[Tensorflow] Support tf.dataset.repeat() to avoid duplicating and dropping samples in one epoch with shuffle? #673

Open chongxiaoc opened 3 years ago

chongxiaoc commented 3 years ago

The current implementation of make_petastorm_dataset for TensorFlow doesn't support multiple iterations: https://github.com/uber/petastorm/blob/7f37e8dde6ff1b13f055d22a6289e2de8bb5d473/petastorm/tf_utils.py#L370

Instead, it is recommended to set the reader's num_epochs > 1 to support multiple training iterations.
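For context, that recommended setup looks roughly like the sketch below; the dataset URL, epoch count, buffer size, and batch size are placeholders, and the training step is elided:

```python
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

# Placeholder URL; any Petastorm-materialized Parquet dataset works here.
# The reader loops over the data num_epochs times because the resulting
# tf.data.Dataset can only be iterated once.
with make_reader('file:///tmp/my_dataset', num_epochs=10) as reader:
    dataset = make_petastorm_dataset(reader)
    dataset = dataset.shuffle(1000)
    for batch in dataset.batch(32):
        pass  # one training step per batch
```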

Combining a multi-epoch reader with `tf.data.Dataset.shuffle`, however, can cause samples to be duplicated and dropped within a single epoch. Recall how the shuffle buffer works (from the TensorFlow documentation):

> For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the 1,000 element buffer.

- A training epoch then plays out like this. Suppose the dataset holds 8 samples [0,1,...,7], the shuffle buffer size is 8, and the reader is configured with num_epochs > 1:
  - Shuffle randomly picks 7 and refills the freed slot with 0, the next element the reader has available (already from the second epoch): [0,1,2,3,4,5,6,0]
  - It is now quite likely that 0 will be selected twice within the next few draws, even before some other values are selected once.
  - By the end of the epoch, 0 has most likely been duplicated while some other values have been dropped.

- This example is reproducible in our real applications. The reason is that `tf.data.Dataset.shuffle` always refills the shuffle buffer with the next element available from the reader, and because the reader is configured for multiple epochs, elements from the next epoch leak into the current one. A standalone reproduction follows below.
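The refill behavior is easy to reproduce with plain tf.data, independent of Petastorm. Below is an illustrative sketch (not Petastorm code) that emulates a multi-epoch reader with `Dataset.range(...).repeat(...)` and counts duplicates and drops within the first epoch-sized slice:

```python
import collections

import tensorflow as tf

NUM_SAMPLES = 8  # one reader epoch: elements 0..7
NUM_EPOCHS = 3   # emulates a reader created with num_epochs > 1

# Epochs streamed back to back through a single shuffle buffer: each time
# an element is drawn, its slot is refilled with the next element from the
# stream, which may already belong to the following epoch.
stream = tf.data.Dataset.range(NUM_SAMPLES).repeat(NUM_EPOCHS)
shuffled = stream.shuffle(buffer_size=NUM_SAMPLES)

first_epoch = list(shuffled.take(NUM_SAMPLES).as_numpy_iterator())
counts = collections.Counter(first_epoch)
print('first epoch:', first_epoch)
print('duplicated: ', sorted(s for s, c in counts.items() if c > 1))
print('dropped:    ', sorted(s for s in range(NUM_SAMPLES) if counts[s] == 0))
```

On a typical run the first 8 elements already contain at least one duplicate and one missing value, matching the walkthrough above.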

- A work-around to fix this:
  - Set the reader's num_epochs to 1.
  - Enable multiple iterations in `make_petastorm_dataset`.
  - Use shuffle and repeat from tf.data as follows:

```python
dataset = make_petastorm_dataset(reader)
dataset = dataset.shuffle(8)
dataset = dataset.repeat(num_epochs)
```


- In this case, `repeat` isolates the epochs from one another: the shuffle buffer drains completely before the dataset restarts, so we no longer see samples dropped or duplicated within a single epoch (verified on our applications). A fuller sketch of this flow follows below.
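End to end, the proposed flow would look like the sketch below. It assumes `make_petastorm_dataset` is changed so the returned dataset can re-read from a `num_epochs=1` reader on each iteration, which is exactly what this issue asks for; the URL is a placeholder:

```python
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

num_epochs = 10

# The reader yields every sample exactly once per pass.
with make_reader('file:///tmp/my_dataset', num_epochs=1) as reader:
    dataset = make_petastorm_dataset(reader)  # assumed re-iterable (proposed)
    # shuffle() before repeat(): the shuffle buffer drains completely at
    # the end of each pass, so no sample from epoch N+1 can displace a
    # sample from epoch N.
    dataset = dataset.shuffle(8)
    dataset = dataset.repeat(num_epochs)
```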

@selitvin What do you think?