uber / petastorm

Petastorm library enables single machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as Tensorflow, Pytorch, and PySpark and can be used from pure Python code.
Apache License 2.0

Simplify data conversion from Spark to PyTorch DataLoader #505

Closed · liangz1 closed this 4 years ago

liangz1 commented 4 years ago

What changes are proposed in this PR?

Add converter.make_torch_dataloader() with advanced params.

The latest API

def make_spark_converter(
        df,
        parquet_row_group_size_bytes=DEFAULT_ROW_GROUP_SIZE_BYTES,
        compression_codec=None):
    """
    Convert a spark dataframe into a :class:`SparkDatasetConverter` object.
    It will materialize a spark dataframe to the directory specified by
    spark conf 'petastorm.spark.converter.parentCacheDirUrl'.
    The dataframe will be materialized in parquet format, and we can specify
    `parquet_row_group_size_bytes` and `compression_codec` for the parquet
    format. See params documentation for details.

    The returned `SparkDatasetConverter` object will hold the materialized
    dataframe, and can be used to make one or more tensorflow datasets or
    torch dataloaders.

    We can explicitly delete the materialized dataframe data by calling
    `SparkDatasetConverter.delete`. When the spark application exits, it will
    make a best-effort attempt to delete the materialized dataframe data.

    :param df: The :class:`DataFrame` object to be converted.
    :param parquet_row_group_size_bytes: An int denoting the number of bytes
        in a parquet row group when materializing the dataframe.
    :param compression_codec: The compression codec to use. It can be one of
        'uncompressed', 'bzip2', 'gzip', 'lz4', 'snappy', 'deflate'.
        Defaults to None, which leaves the data uncompressed.

    :return: a :class:`SparkDatasetConverter` object that holds the
        materialized dataframe and can be used to make one or more tensorflow
        datasets or torch dataloaders.
    """

class SparkDatasetConverter(object):
    """
    A `SparkDatasetConverter` object holds one materialized spark dataframe and
    can be used to make one or more tensorflow datasets or torch dataloaders.
    The `SparkDatasetConverter` object is picklable and can be used in remote
    processes.
    See `make_spark_converter`
    """

    PARENT_CACHE_DIR_URL_CONF = 'petastorm.spark.converter.parentCacheDirUrl'

    def __init__(self, cache_dir_url, dataset_size):
        """
        :param cache_dir_url: A string denoting the path to store the cache
            files.
        :param dataset_size: An int denoting the number of rows in the
            dataframe.
        """

    def __len__(self):
        """
        :return: dataset size
        """

    def make_tf_dataset(self):
        """
        Make a tensorflow dataset.

        This method will do the following two steps:
          1) Open a petastorm reader on the materialized dataset dir.
          2) Create a tensorflow dataset based on the reader created in (1)

        :return: a context manager for a `tf.data.Dataset` object.
                 When the returned context manager exits, the reader
                 will be closed.
        """

    def make_torch_dataloader(self,
                              batch_size=32,
                              num_epochs=None,
                              workers_count=None,
                              cur_shard=None,
                              shard_count=None,
                              **petastorm_reader_kwargs):
        """
        Make a PyTorch DataLoader.

        This method will do the following two steps:
          1) Open a petastorm reader on the materialized dataset dir.
          2) Create a PyTorch DataLoader based on the reader created in (1)

        :param batch_size: The number of items to return per batch
        :param num_epochs: An epoch is a single pass over all rows in the
            dataset. Setting ``num_epochs`` to ``None`` will result in an
            infinite number of epochs.
        :param workers_count: An int for the number of workers to use in the
            reader pool. This is only used for the thread or process pool.
            Defaults to None, which means using the default value from
            `petastorm.make_batch_reader()`.
        :param cur_shard: An int denoting the current shard number. Each node
            reading a shard should pass in a unique shard number in the range
            [0, shard_count). shard_count must be supplied as well. Defaults to
            None.
        :param shard_count: An int denoting the number of shards to break this
            dataset into. Defaults to None.
        :param petastorm_reader_kwargs: all the arguments for
            `petastorm.make_batch_reader()`.

        :return: a context manager for a `torch.utils.data.DataLoader` object.
                 When the returned context manager exits, the reader
                 will be closed.
        """

    def delete(self):
        """
        Delete cache files at self.cache_dir_url.
        """

Example Code (PyTorch)

from petastorm import make_spark_converter
from petastorm.spark import SparkDatasetConverter
import torch

# specify a cache dir first.
# the dir is used to save materialized spark dataframe files
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'hdfs:/...')

df1 = ... # `df1` is a spark dataframe

# create a converter from `df1`
# it will materialize `df1` to cache dir.
converter1 = make_spark_converter(df1)

# make a PyTorch DataLoader from `converter1`
with converter1.make_torch_dataloader() as dataloader:
    # the `dataloader` is `torch.utils.data.DataLoader` object
    # we can train/evaluate model on `dataloader`
    # when the with context exits, the reader of the dataloader will be closed
    ...

# delete the cached files of the dataframe.
converter1.delete()
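
A similar sketch for tensorflow, using make_tf_dataset from the API above (a hedged example; it creates a fresh converter since the cache behind `converter1` was just deleted):

converter2 = make_spark_converter(df1)

# make a tensorflow dataset from `converter2`
with converter2.make_tf_dataset() as dataset:
    # the `dataset` is a `tf.data.Dataset` object
    # we can train/evaluate a model on `dataset`
    # when the with context exits, the reader behind the dataset will be closed
    ...

# delete the cached files of the dataframe.
converter2.delete()
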
codecov[bot] commented 4 years ago

Codecov Report

Merging #505 into master will not change coverage. The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #505   +/-   ##
=======================================
  Coverage   86.17%   86.17%           
=======================================
  Files          81       81           
  Lines        4421     4421           
  Branches      704      704           
=======================================
  Hits         3810     3810           
  Misses        502      502           
  Partials      109      109           

Continue to review full report at Codecov.


WeichenXu123 commented 4 years ago

To mock make_batch_reader for testing, the following simple code should work:

from contextlib import contextmanager

@contextmanager
def mock_make_batch_reader():
    # Temporarily replace petastorm.make_batch_reader with a wrapper that
    # records the arguments of every call, then restores the original.
    captured_args = []
    import petastorm
    original_make_batch_reader = petastorm.make_batch_reader

    def mock_fn(dataset_url, **kwargs):
        captured_args.append({'dataset_url': dataset_url, **kwargs})
        return original_make_batch_reader(dataset_url, **kwargs)

    petastorm.make_batch_reader = mock_fn
    try:
        yield captured_args
    finally:
        petastorm.make_batch_reader = original_make_batch_reader


with mock_make_batch_reader() as captured_args:
    # importing inside the context picks up the patched function
    from petastorm import make_batch_reader
    with make_batch_reader('file:///tmp/t0001', workers_count=18) as reader:
        for i in reader:
            print(i)
    print('get captured args: ' + str(captured_args))
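
An equivalent capture could also be done with the standard library's unittest.mock; this is only a sketch, and the patch target 'petastorm.make_batch_reader' is an assumption that has to match wherever the code under test actually resolves the function:

from unittest import mock
import petastorm

# wrap the real function so calls still go through while their args are recorded
with mock.patch('petastorm.make_batch_reader',
                wraps=petastorm.make_batch_reader) as mocked:
    from petastorm import make_batch_reader
    with make_batch_reader('file:///tmp/t0001', workers_count=18) as reader:
        for i in reader:
            print(i)
    print('get captured args: ' + str(mocked.call_args_list))
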
WeichenXu123 commented 4 years ago

Let's wait for PR https://github.com/uber/petastorm/pull/506 to merge first, and then reuse some code from there.

praateekmahajan commented 4 years ago

@liangz1 @mengxr @selitvin do we know if this PR can leverage the performance boost being offered in #492? If yes, it might be a nice idea to get that merged too, given all tests pass.

Also, a minor nit on the PR title: shouldn't it be the following?

Simplify data conversion from Spark to PyTorch DataLoader

WeichenXu123 commented 4 years ago

Looks Good!