uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark, and can be used from pure Python code.
Apache License 2.0

Bug in ConcurrentVentilator._ventilate() when randomize_item_order=True and random seed is fixed #798

Open · JonasRauch opened this issue 11 months ago

Hi Petastorm team,

I noticed that make_tf_dataset with shuffle_row_groups=True generates a non-deterministic order of rows starting in epoch 2, even if I fix the random seed. I am using workers_count=1 to remove randomness from parallel processing (which is addressed separately in #796). I was able to reproduce the problem when using make_batch_reader directly.

Example code to reproduce the problem

  1. Create a Spark DF with 10 rows and use repartition to force multiple data files
  2. Create a Petastorm converter
  3. Set petastorm_converter.file_urls = list(sorted(petastorm_converter.file_urls)) to fix the order of input files. Note: This should probably be done somewhere in the Petastorm library, because without this I saw additional randomness from the order in which Spark lists the data files
  4. Create a Reader with `workers_count=1, seed=1, shuffle_rows=False, shuffle_row_groups=True, num_epochs=2`
  5. Read all batches from the reader into a new Pandas DF
  6. Repeat 4. and 5.
  7. Compare the resulting Dataframes

I get differently ordered results about 50% of the time on my system, so re-running steps 4-7 a few times should eventually reproduce the problem (a helper loop for this is sketched after the code listing below).

import pandas as pd
import petastorm
import pyspark

def read_all_batches(reader):
    data = []
    for i, batch in enumerate(reader):
        batch = batch._asdict()
        data.append(pd.DataFrame(batch))
    data = pd.concat(data, axis=0, ignore_index=True)
    return data

# create example DF
spark = pyspark.sql.SparkSession.builder.getOrCreate()

df = spark.range(10)
df = df.sort("id")
df = df.repartition(5)

# set Petastorm cache path
spark.conf.set(petastorm.spark.SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'file:///tmp/petastorm')
# create petastorm converter
petastorm_converter = petastorm.spark.make_spark_converter(df)
# fix file order
petastorm_converter.file_urls = list(sorted(petastorm_converter.file_urls))

reader_1 = petastorm.make_batch_reader(petastorm_converter.file_urls, workers_count=1, seed=1, shuffle_rows=False, shuffle_row_groups=True, num_epochs=2)
reader_2 = petastorm.make_batch_reader(petastorm_converter.file_urls, workers_count=1, seed=1, shuffle_rows=False, shuffle_row_groups=True, num_epochs=2)

data_1 = read_all_batches(reader_1)
data_2 = read_all_batches(reader_2)

if all(data_1["id"] == data_2["id"]):
    print("OK")
else:
    print("Not OK")
    print(pd.concat((data_1, data_2), axis=1))
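
Since the mismatch only shows up about half the time, a small loop around the reader creation and comparison (steps 4-7) makes it easier to hit. This is just a convenience sketch reusing petastorm_converter and read_all_batches from above; n_trials is an arbitrary count chosen for illustration.

# Repeat steps 4-7; each trial builds two identically seeded readers and compares the row order.
n_trials = 10  # arbitrary number of repetitions, for illustration only
for trial in range(n_trials):
    reader_1 = petastorm.make_batch_reader(petastorm_converter.file_urls, workers_count=1, seed=1, shuffle_rows=False, shuffle_row_groups=True, num_epochs=2)
    reader_2 = petastorm.make_batch_reader(petastorm_converter.file_urls, workers_count=1, seed=1, shuffle_rows=False, shuffle_row_groups=True, num_epochs=2)
    data_1 = read_all_batches(reader_1)
    data_2 = read_all_batches(reader_2)
    print(f"trial {trial}:", "OK" if all(data_1["id"] == data_2["id"]) else "Not OK")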

Root cause

I was able to trace the root cause to the _ventilate() method of the ConcurrentVentilator class, specifically the block that shuffles _items_to_ventilate at the start of an epoch. The problem is that this shuffle happens before the check that sleeps and continues when the queue has no room, so while the queue is full the ventilator keeps looping back to the top and re-shuffles on every polling attempt. Each extra shuffle advances the random state, and the number of polling attempts depends on worker timing, so the resulting order is not reproducible even with a fixed seed.

Example code to directly reproduce the problem with ConcurrentVentilator

This code snippet implements a subclass of ConcurrentVentilator that overrides the _ventilate() method. The logic is an exact copy of the original method, with additional print statements inside the loop to show what is going on.

from petastorm.workers_pool.ventilator import ConcurrentVentilator
from petastorm.workers_pool.thread_pool import ThreadPool
from petastorm.workers_pool.worker_base import WorkerBase
from time import sleep

class DebugConcurrentVentilator(ConcurrentVentilator):
    def _ventilate(self):
        while True:
            # Stop condition is when no iterations are remaining or there are no items to ventilate
            if self.completed():
                break

            # If we are ventilating the first item, we check if we would like to randomize the item order
            if self._current_item_to_ventilate == 0 and self._randomize_item_order:
                print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - shuffling")
                self._random_state.shuffle(self._items_to_ventilate)

            # Block until queue has room, but use continue to allow for checking if stop has been called
            if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size:
                sleep(self._ventilation_interval)
                print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - waiting for room in the queue")
                continue

            item_to_ventilate = self._items_to_ventilate[self._current_item_to_ventilate]
            print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - ventilating")
            self._ventilate_fn(**item_to_ventilate)
            self._current_item_to_ventilate += 1
            self._ventilated_items_count += 1

            if self._current_item_to_ventilate >= len(self._items_to_ventilate):
                self._current_item_to_ventilate = 0
                # If iterations was set to None, that means we will iterate until stop is called
                if self._iterations_remaining is not None:
                    self._iterations_remaining -= 1

items = [{'index': i} for i in range(3)]
pool = ThreadPool(1)

ventilator = DebugConcurrentVentilator(
    ventilate_fn=pool.ventilate,
    items_to_ventilate=items,
    iterations=2,
    randomize_item_order=True,
    random_seed=0,
    max_ventilation_queue_size=1,
)

class TestWorker(WorkerBase):
    def __init__(self, worker_id, publish_func, args):
        super().__init__(worker_id, publish_func, args)

    def process(self, index):
        # Simulate a slow worker, then tell the ventilator the item has been consumed
        sleep(0.1)
        ventilator.processed_item()

pool.start(TestWorker, ventilator=ventilator)
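
Depending on how the pool's threads are created, the script above may not exit on its own once both iterations have finished. A possible cleanup, assuming the stop()/join() methods that petastorm's Reader uses on its workers pool:

sleep(2)     # crude wait for the ventilator to finish both iterations
pool.stop()  # assumption: ThreadPool exposes stop()/join(), as used by petastorm's Reader
pool.join()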

Example output

_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - ventilating
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - ventilating
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - ventilating
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - ventilating
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - ventilating
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - ventilating

Note that at the beginning of the second iteration there are 7 shuffle operations before the first item of that epoch is eventually ventilated. Each of these shuffles advances the random state, and their number depends on how long the worker takes to drain the queue, so the item order is non-deterministic even though the seed is fixed.
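
The effect is easy to demonstrate in isolation: with a fixed seed, the ordering after one shuffle differs from the ordering after several shuffles, because each call advances the random state. A minimal illustration using numpy's RandomState (any seeded RNG behaves the same way):

import numpy as np

items = list(range(5))

# One shuffle per epoch, as intended: fully determined by the seed.
rs = np.random.RandomState(0)
a = list(items)
rs.shuffle(a)

# Seven shuffles, as in the log above (one per polling attempt while the queue was full):
rs = np.random.RandomState(0)
b = list(items)
for _ in range(7):
    rs.shuffle(b)

print(a, b)  # the two orderings generally differ despite the identical seed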

Proposed fix

This can be fixed easily by first checking whether the queue has room, and shuffling and ventilating only if it does:


class FixedConcurrentVentilator(ConcurrentVentilator):
    def _ventilate(self):
        while True:
            # Stop condition is when no iterations are remaining or there are no items to ventilate
            if self.completed():
                break

            # Block until queue has room, but use continue to allow for checking if stop has been called
            if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size:
                sleep(self._ventilation_interval)
                print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - waiting for room in the queue")
                continue

            # If we are ventilating the first item, we check if we would like to randomize the item order
            if self._current_item_to_ventilate == 0 and self._randomize_item_order:
                print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - shuffling")
                self._random_state.shuffle(self._items_to_ventilate)

            item_to_ventilate = self._items_to_ventilate[self._current_item_to_ventilate]
            print(f"_current_item_to_ventilate: {self._current_item_to_ventilate} - ventilating")
            self._ventilate_fn(**item_to_ventilate)
            self._current_item_to_ventilate += 1
            self._ventilated_items_count += 1

            if self._current_item_to_ventilate >= len(self._items_to_ventilate):
                self._current_item_to_ventilate = 0
                # If iterations was set to None, that means we will iterate until stop is called
                if self._iterations_remaining is not None:
                    self._iterations_remaining -= 1

Example output of running the same example code as above but with FixedConcurrentVentilator instead of DebugConcurrentVentilator:

_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - ventilating
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - ventilating
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - ventilating
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - waiting for room in the queue
_current_item_to_ventilate: 0 - shuffling
_current_item_to_ventilate: 0 - ventilating
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - waiting for room in the queue
_current_item_to_ventilate: 1 - ventilating
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - waiting for room in the queue
_current_item_to_ventilate: 2 - ventilating

Note that there is now exactly one shuffle operation at the beginning of each iteration, so the random state advances the same way on every run.
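
A quick way to verify the fix is to record the order in which items are handed to the ventilate function for two runs with the same seed and check that they match. The helper below is only a sketch: ventilation_order is a made-up name, and it assumes ConcurrentVentilator's start() method (completed() and processed_item() already appear in the code above) and drives the ventilator without a worker pool.

def ventilation_order(seed):
    order = []

    def record(**kwargs):
        # Stand-in for pool.ventilate: just remember the shuffled order
        order.append(kwargs['index'])

    vent = FixedConcurrentVentilator(
        ventilate_fn=record,
        items_to_ventilate=[{'index': i} for i in range(3)],
        iterations=2,
        randomize_item_order=True,
        random_seed=seed,
        max_ventilation_queue_size=1,
    )
    vent.start()  # assumption: start() launches the ventilation thread, as when used via a workers pool
    acknowledged = 0
    while not vent.completed():
        if len(order) > acknowledged:
            vent.processed_item()  # acknowledge one item so the queue frees up
            acknowledged += 1
        sleep(0.01)
    return order

print(ventilation_order(0) == ventilation_order(0))  # expected: True with the fixed ventilator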