uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark and can be used from pure Python code.
Apache License 2.0

Best way to load folder of images into petastorm data set? #497

filipski closed this issue 4 years ago

filipski commented 4 years ago

Hi,

I have a folder of images in png format and I need to load them into a petastorm data set with a given schema. For simplicity, let's say it's just a path and an encoded image. I tried loading all images into a dataframe with spark.read.format("image").load(image_files) and then storing this dataframe with materialize_dataset. Is this the way to go? I'm not sure whether the images end up png-compressed in the parquet files. And sometimes I get weird errors: it seems that petastorm sorts the columns alphabetically, so I need to create the Unischema and my final dataframe with columns ordered by name, too.
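A Spark-free sketch of the column-order problem, under two assumptions: that petastorm orders Unischema fields by name (as observed in the question) and that createDataFrame pairs the values of each row from set_df.rdd with the schema fields by position. SetRow and the sample values are hypothetical.

```python
from collections import namedtuple

# Assumption: petastorm keeps Unischema fields sorted by name, so the Spark
# schema built from ImageSchema lists 'image' before 'path'.
schema_field_names = sorted(['path', 'image'])

# Hypothetical row as it would come out of set_df.rdd: select() order (path, image).
SetRow = namedtuple('SetRow', ['path', 'image'])
row = SetRow(path='file:/images/0001.png', image=b'\x89PNG\r\n')

# Pairing values with schema names by position swaps the two columns...
by_position = dict(zip(schema_field_names, row))

# ...while looking each field up by name keeps them aligned.
by_name = {name: getattr(row, name) for name in schema_field_names}

print(by_position)
print(by_name)
```

If both assumptions hold, this is consistent with the errors described: the path string lands in the binary image column. Building rows by name, or selecting the columns in sorted order, would sidestep it.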

Here's the sample code I'm trying now:

#!/usr/bin/env python3

import os, sys
import numpy as np

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructField, StructType, IntegerType, BinaryType, StringType, TimestampType

from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

ROWGROUP_SIZE_MB = 128 # The same as the default HDFS block size

# The schema defines how the dataset schema looks like
ImageSchema = Unischema('ImageSchema', [
    UnischemaField('path', np.str_, (), ScalarCodec(StringType()), False),
    UnischemaField('image', np.uint8, (1080, 1280, 3), CompressedImageCodec('png'), False)
])

output_url = "file:///tmp/petastorm_ingest_test"
rows_count = 1

def ingest_folder(images_folder, spark):

    image_files = "file://" + os.path.abspath(images_folder) + "/*.png"
    print(image_files)
    # Read all images at once
    image_df = spark.read.format("image").load(image_files)

    print('Schema of image_df')
    print('--------------------------')
    image_df.printSchema()

    with materialize_dataset(spark, output_url, ImageSchema, ROWGROUP_SIZE_MB):

        set_df = image_df.select(image_df.image.origin.alias('path'), image_df.image.data.alias('image'))

        print('Schema of set_df')
        print('--------------------------')
        set_df.printSchema()
        print(ImageSchema.as_spark_schema())

        print('Saving to parquet')

        # Alternative (plain Spark write, without petastorm encoding):
        # set_df.write \
        #         .mode('overwrite') \
        #         .parquet(output_url)

        spark.createDataFrame(set_df.rdd, ImageSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)

def main():

    # Start the Spark session
    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[*]').getOrCreate()    
    sc = spark.sparkContext

    # Ingest images and annotations for a given folder
    ingest_folder("../images/", spark)

if __name__ == '__main__':
    main()
selitvin commented 4 years ago

From running your code I see that image_df.printSchema indicates that spark.read.format("image") ends up creating a structure of fields:

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)
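The struct above carries everything needed to rebuild a NumPy array for each row. A minimal sketch, assuming Spark's documented layout for the image source (data holds row-major pixels in OpenCV channel order, i.e. BGR for 3-channel images) and assuming, as far as I can tell, that petastorm's CompressedImageCodec expects RGB input. spark_image_to_array is a hypothetical helper, not part of Spark or petastorm.

```python
import numpy as np

def spark_image_to_array(height, width, n_channels, data):
    """Rebuild an HxWxC uint8 array from the fields of Spark's image struct.

    Assumes `data` is row-major pixels in OpenCV channel order, so
    3-channel images are flipped from BGR to RGB; other channel counts
    are returned unchanged.
    """
    pixels = np.frombuffer(bytes(data), dtype=np.uint8)
    image = pixels.reshape((height, width, n_channels))
    if n_channels == 3:
        image = image[:, :, ::-1]  # BGR -> RGB
    return np.ascontiguousarray(image)

# A 1x2 "image" with two BGR pixels: (0, 1, 2) and (3, 4, 5).
rgb = spark_image_to_array(1, 2, 3, bytearray([0, 1, 2, 3, 4, 5]))
print(rgb.shape, rgb[0, 0].tolist(), rgb[0, 1].tolist())
```

In a Spark job the arguments would come from image.height, image.width, image.nChannels and image.data of each row.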

This is currently not a format petastorm expects (although it might be a good idea to support it). Please see this example for the currently supported pattern. In the example:

Hope this helps.

filipski commented 4 years ago

Yes, this is the default schema when Spark loads images with the load() function. But I'm flattening it out to the same one defined in ImageSchema:

root
 |-- path: string (nullable = true)
 |-- image: binary (nullable = true)

in the following line:

set_df = image_df.select(image_df.image.origin.alias('path'), image_df.image.data.alias('image'))

I've just updated the code above to be a bit more verbose about the schemas.

And yes, I relied on the sample you sent, but it requires creating and parallelizing an RDD, which seems redundant if one already has a dataframe. I also want to avoid the RDD because I'm ingesting images from a local file system on a single machine, which is not shared with the workers, and I will then try to save the petastorm data set to distributed HDFS. Is there a better way, or is dict_to_spark_row strictly required?
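As far as I understand it, the per-row work of dict_to_spark_row is small: check the dict against the schema, run each field's codec, and emit the values in the schema's column order. A simplified, Spark-free stand-in (not petastorm's actual code; nullability checks omitted) makes those steps explicit, which suggests they could also be reproduced another way, for example with UDFs. The schema dict and the byte-reversing "encoder" are hypothetical placeholders for a Unischema and a real PNG codec.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical schema: field name -> encoder, or None to store the value as-is.
Schema = Dict[str, Optional[Callable[[Any], Any]]]

def dict_to_row_values(schema: Schema, row: Dict[str, Any]) -> Tuple[Any, ...]:
    """Encode one row dict and return its values in sorted field-name order."""
    if set(row) != set(schema):
        raise ValueError('row fields %s do not match schema fields %s'
                         % (sorted(row), sorted(schema)))
    values = []
    for name in sorted(schema):  # the column order the Spark schema is assumed to use
        encode = schema[name]
        values.append(row[name] if encode is None else encode(row[name]))
    return tuple(values)

# Stand-in encoder (reverses the bytes) instead of a real PNG codec.
schema = {'path': None, 'image': lambda raw: bytes(reversed(raw))}
print(dict_to_row_values(schema, {'path': 'a.png', 'image': b'abc'}))
```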

filipski commented 4 years ago

FYI, I've now done it the way shown in the code below. It seems to work (the total size of the parquet files in the output folder is almost identical to the total size of the input png files), but I'd say it's a bit ugly, as the data goes from the initial data frame -> a set of dictionaries -> an RDD -> the output data frame. Is there a way to avoid going through an RDD and dict_to_spark_row? I guess one would need to write a UDF similar to dict_to_spark_row and convert the original dataframe columns using this UDF and withColumn(), am I right? This would have to be done in a smart way, with no code duplication between that UDF and dict_to_spark_row, as otherwise it would be harder to maintain.

Additionally, I took a look at https://github.com/uber/petastorm/blob/a61fe13d5932f9bb9ff6a9e7fa8b7c2dfd5016e3/examples/imagenet/generate_petastorm_imagenet.py#L138 where you set .option('compression', 'none'). Is that recommended for data sets containing png files, which are already compressed? What if a data set contained both png images and JSON annotations in another column?

#!/usr/bin/env python3

import os, sys
import numpy as np

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructField, StructType, IntegerType, BinaryType, StringType, TimestampType

from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

ROWGROUP_SIZE_MB = 128 # The same as the default HDFS block size

# The schema defines how the dataset schema looks like
ImageSchema = Unischema('ImageSchema', [
    UnischemaField('path', np.str_, (), ScalarCodec(StringType()), False),
    UnischemaField('image', np.uint8, (1080, 1280, 3), CompressedImageCodec('png'), False)
])

output_url = "file:///tmp/petastorm_ingest_test"
rows_count = 1

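def row_generator(x):
    """Converts a single row into a dictionary with a decoded RGB image"""
    print("Creating row for: " + x.path)
    # Spark stores raw pixels in BGR order; flip to RGB for CompressedImageCodec
    image = np.frombuffer(x.image, dtype=np.uint8).reshape((1080, 1280, 3))[:, :, ::-1]
    return {'path': x.path,
            'image': image}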

def ingest_folder(images_folder, spark):

    image_files = "file://" + os.path.abspath(images_folder) + "/*.png"
    print(image_files)
    # Read all images at once
    image_df = spark.read.format("image").load(image_files)

    print('Schema of image_df')
    print('--------------------------')
    image_df.printSchema()

    with materialize_dataset(spark, output_url, ImageSchema, ROWGROUP_SIZE_MB):

        set_df = image_df.select(image_df.image.origin.alias('path'), image_df.image.data.alias('image'))

        print('Schema of set_df')
        print('--------------------------')
        set_df.printSchema()
        print(ImageSchema.as_spark_schema())

        print('Saving to parquet')
        print('--------------------------')

        rows_rdd = set_df.rdd\
            .map(row_generator)\
            .map(lambda x: dict_to_spark_row(ImageSchema, x))

        spark.createDataFrame(rows_rdd, ImageSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)

def main():

    # Start the Spark session
    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[*]').getOrCreate()    
    sc = spark.sparkContext

    # Ingest images and annotations for a given folder
    ingest_folder("../images/", spark)

if __name__ == '__main__':
    main()
filipski commented 4 years ago

To overload you with information/questions :) - I created a second version that loads images with OpenCV (a simplified version of https://github.com/uber/petastorm/blob/master/examples/imagenet/generate_petastorm_imagenet.py). I'm still not a fan of loading the images into an RDD first, but this version runs roughly twice as fast as the original ingest_folder() above. Is it really the correct/best way to ingest images from a local file system into a petastorm data set?


from glob import glob
import cv2

def ingest_folder2(images_folder, spark):

    # List all images in the folder
    image_files = sorted(glob(os.path.join(images_folder, "*.png")))

    with materialize_dataset(spark, output_url, ImageSchema, ROWGROUP_SIZE_MB):

        input_rdd = spark.sparkContext.parallelize(image_files) \
            .map(lambda image_path:
                    {ImageSchema.path.name: image_path,
                     # cv2.imread returns BGR; CompressedImageCodec expects RGB
                     ImageSchema.image.name: cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)})

        rows_rdd = input_rdd.map(lambda r: dict_to_spark_row(ImageSchema, r))

        spark.createDataFrame(rows_rdd, ImageSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .option('compression', 'none') \
            .parquet(output_url)
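Two failure modes are worth guarding against in this version: cv2.imread returns None rather than raising when it cannot read a file, and the schema fixes every image at (1080, 1280, 3), so a single off-size PNG would presumably fail the codec's shape check and take the whole write job down. A hypothetical filter (usable and matches_shape are not petastorm functions), sketched in plain Python so it runs without OpenCV:

```python
from typing import Optional, Sequence, Tuple

def matches_shape(shape: Sequence[int], expected: Tuple[Optional[int], ...]) -> bool:
    """True if `shape` has the expected rank and every non-None dimension matches."""
    return len(shape) == len(expected) and all(
        want is None or have == want for have, want in zip(shape, expected))

def usable(image, expected=(1080, 1280, 3)) -> bool:
    """Reject failed reads (None from cv2.imread) and images of the wrong size."""
    return image is not None and matches_shape(image.shape, expected)

# Possible use in ingest_folder2, before dict_to_spark_row (not executed here):
#   input_rdd = input_rdd.filter(lambda r: usable(r[ImageSchema.image.name]))
```

Filtering drops bad files silently; logging the rejected paths instead may be preferable if a complete data set matters.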
selitvin commented 4 years ago

Very interesting - thank you for the information and questions! Sorry I missed the flattening of the dataframe in your original example...

I agree that it would be nice to be able to use dataframes directly. However, I would imagine the performance benefit of staying in DataFrame land comes from being able to do all processing in Scala. Petastorm, on the other hand, is oriented towards consumers of the data (deep-learning frameworks), which are all Python, so if we want to stick with the main code base, we would need to use Python UDFs. I am not very familiar with pyspark/spark internals, but I would imagine that if we do that:

One interesting idea could be to store the DataFrame with images without materialize_dataset, i.e. just store the png bitstreams. Then you could use make_batch_reader to read the data directly from parquet. You would probably need to write custom code to decompress the images, or maybe we could add some magic to do it for you as part of the petastorm code.
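The custom decompression step could be a small post-processing function applied to each batch. A minimal sketch, assuming (my understanding, not verified here) that make_batch_reader yields each batch as a namedtuple of per-column sequences; decode_column and the Batch type are hypothetical, and zlib stands in for PNG so the example runs with the standard library alone.

```python
import zlib
from collections import namedtuple

def decode_column(batch, column, decode):
    """Return a copy of a namedtuple batch with one column decoded element-wise."""
    return batch._replace(**{column: [decode(v) for v in getattr(batch, column)]})

# Hypothetical two-row batch; zlib-compressed bytes stand in for PNG bitstreams.
Batch = namedtuple('Batch', ['path', 'image'])
batch = Batch(path=['a.png', 'b.png'],
              image=[zlib.compress(b'pixels-a'), zlib.compress(b'pixels-b')])

decoded = decode_column(batch, 'image', zlib.decompress)
print(decoded.image)
```

With real PNG bitstreams, decode could be something like lambda b: cv2.imdecode(np.frombuffer(b, np.uint8), cv2.IMREAD_UNCHANGED), which returns BGR for colour images.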

Regarding compression: of course it depends on your data. However, in the cases I have encountered, image size dominates any metadata stored together with the images, so not compressing the data (on top of png) was the right thing for me to do.
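A quick standard-library experiment illustrates the trade-off. zlib stands in for Parquet's compression codecs, random bytes stand in for PNG payloads (an assumption: a well-compressed bitstream is close to incompressible), and the JSON annotations are synthetic, so the numbers are illustrative rather than measured on a real data set.

```python
import json
import random
import zlib

def compression_ratio(data: bytes) -> float:
    """Compressed size divided by original size; close to 1.0 means no gain."""
    return len(zlib.compress(data, 6)) / len(data)

# Stand-in for a PNG payload: compressed bitstreams look like random bytes.
rng = random.Random(0)
png_like = bytes(rng.getrandbits(8) for _ in range(200_000))

# Stand-in for a JSON annotation column: repetitive, highly compressible text.
annotations = json.dumps(
    [{'label': 'car', 'bbox': [i, i + 1, i + 40, i + 80], 'score': 0.9}
     for i in range(2_000)]).encode('utf-8')

print('png-like: %.3f' % compression_ratio(png_like))
print('json:     %.3f' % compression_ratio(annotations))
```

A second compression pass buys essentially nothing for the image-like column, while a text column such as JSON annotations still shrinks substantially.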

filipski commented 4 years ago

Thanks for the follow-up! Indeed it looks like a lot of changes with no promise of a performance gain, especially since, just comparing the execution time of the two ingest_folder functions above on the same image set, I find the dataframe version almost two times slower than the pure RDD one. But that's probably not a fair comparison, as the data is moved from a data frame to an RDD there anyway, with a dictionary in between, so it probably wouldn't be that slow if that were avoided.

I wonder how the pure RDD approach scales. I believe that's how you ingest your data, right? Or do you first distribute the images to all nodes of the cluster in some way before reading them into an RDD?

And regarding compression - maybe it makes sense to introduce a standard codec that compresses strings with e.g. lz4/zlib/bzip, which could be applied at the column level, like CompressedImageCodec?

selitvin commented 4 years ago

In our scenarios performance of writing the data was not critical since it is done infrequently, but I agree, if write path performance/memory-efficiency becomes important, the pyspark-ish rdd approach is bad on multiple levels :)

I did not have a need to ingest images. In my work-scenario, data comes from other sources.

Unfortunately, I was not able to find a way to specify parquet compression per column (I was looking for this a while ago). This is a spark/parquet integration issue, though, and Petastorm can't help here, at least in my understanding.

filipski commented 4 years ago

I did not have a need to ingest images. In my work-scenario, data comes from other sources.

I was under the impression that you use a lot of visual data at Uber. Do you happen to have any recommendations on how to handle it with petastorm, e.g. from other teams?

As for compression - I was rather thinking of specifying a compression codec in UnischemaField, like you do for PNGs. Then the compression wouldn't be done at the parquet level but during the call to dict_to_spark_row, but I guess it would then have to be decoded manually when reading it back, right?
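The per-field codec idea can be sketched as a class with the same encode(unischema_field, value) / decode(unischema_field, value) pair that petastorm codecs such as CompressedImageCodec use. ZlibTextCodec is hypothetical and standalone; plugging it into a real UnischemaField would presumably also require subclassing petastorm's codec base class and declaring a Spark binary type, which is not shown here. If it were registered as the field's codec, make_reader would presumably decode it on read the same way it decodes PNG fields, so no manual step would be needed there.

```python
import zlib

class ZlibTextCodec:
    """Hypothetical per-field codec: zlib-compresses UTF-8 text.

    Mirrors the encode/decode(unischema_field, value) shape of petastorm
    codecs, but is not a petastorm class.
    """

    def __init__(self, level: int = 6):
        self.level = level

    def encode(self, unischema_field, value: str) -> bytearray:
        # Called while the row dict is converted, e.g. for a JSON annotation string.
        return bytearray(zlib.compress(value.encode('utf-8'), self.level))

    def decode(self, unischema_field, value) -> str:
        # Reverses encode() when the column is read back.
        return zlib.decompress(bytes(value)).decode('utf-8')

codec = ZlibTextCodec()
annotations = '{"boxes": [[10, 20, 110, 220]], "labels": ["car"]}' * 50
stored = codec.encode(None, annotations)  # the field argument is unused in this sketch
print(len(annotations), len(stored), codec.decode(None, stored) == annotations)
```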

(BTW, I'm closing the issue, as there's nothing to implement on the petastorm side here, but we can keep the discussion open)

filipski commented 4 years ago

Heh, I have crashes with the same code just trying to store data on HDFS. Please take a look at https://github.com/uber/petastorm/issues/502

selitvin commented 4 years ago

I did not have a need to ingest images. In my work-scenario, data comes from other sources.

Sorry, it came out wrong. We do work a lot with images. However, I never had to ingest them from a large set of png files, like you show in the original example, hence simplifying/optimizing the spark.read.format("image").load(image_files) path was never a priority to us.