Open selitvin opened 4 years ago
This would be useful for my project too, is there anything I can do to help and get this issue closed?
Thank you - indeed, a contribution for this issue would help many users. I guess what needs to be done is to check whether the check referenced in the ticket is still needed; if not, then update create_test_scalar_dataset in test_common.py to include a list-of-lists type and make sure the tests pass. Please feel free to reach out if you are planning to take a shot at this issue.
I'll take a crack at this! The test should be as simple as removing the check, building a Spark DataFrame containing a list of lists, and then trying to write it out? And then on the other end, loading the list of lists back in through a PyTorch or similar DataLoader and verifying the outputs look good?
Yeah, I think so. It's actually likely that existing tests will automatically try to compare the new field. Perhaps some comparison logic would need to be tweaked to account for the new data type. I would hope that the places that need to be tweaked will show via failing tests.
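If the comparison logic does need tweaking, normalizing both sides to numpy before asserting is probably enough. A rough sketch (the helper name and sample values are made up for illustration, not petastorm's actual test code):

import numpy as np

def assert_field_equal(expected, actual):
    # Hypothetical helper: a list-of-lists and an ndarray compare cleanly once
    # both sides are converted to numpy arrays of the same shape.
    np.testing.assert_array_equal(np.asarray(expected), np.asarray(actual))

assert_field_equal([[1, 2], [3, 4]], np.array([[1, 2], [3, 4]], dtype=np.uint8))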
As usual, it seems simple at first glance, but there is always some can of worms ready to be opened somewhere :)
Alright, thanks for letting me know. I'll take a first pass at this either this week or early next week and post here again with what I find.
I grabbed the hello world example and modified it like below, but then I realized there is no codec available for a list, only NdarrayCodec. If I try the following code, I get ValueError: Unexpected type of list_of_lists feature. Expected ndarray of <class 'numpy.uint8'>. Got <class 'list'>. Any thoughts on how to proceed?
import torch
from petastorm.pytorch import DataLoader
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row
from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from pyspark.sql.types import IntegerType
import numpy as np
from petastorm.fs_utils import FilesystemResolver
from petastorm import make_reader
from pyspark.sql import SparkSession


def row_generator(x):
    """Returns a single entry in the generated dataset. Return a bunch of random values as an example."""
    return {
        "id": x,
        "list_of_lists": np.random.randint(0, 100, dtype=np.uint8, size=(2, 2)).tolist(),
    }


def generate_hello_world_dataset(spark, HelloWorldSchema, output_url, fsystem):
    rows_count = 10
    rowgroup_size_mb = 256
    sc = spark.sparkContext

    # Wrap dataset materialization portion. Will take care of setting up spark environment variables as
    # well as save petastorm specific metadata
    with materialize_dataset(
        spark, output_url, HelloWorldSchema, rowgroup_size_mb, filesystem_factory=fsystem
    ):
        rows_rdd = (
            sc.parallelize(range(rows_count))
            .map(row_generator)
            .map(lambda x: dict_to_spark_row(HelloWorldSchema, x))
        )

        spark.createDataFrame(rows_rdd, HelloWorldSchema.as_spark_schema()).coalesce(1).write.mode(
            "overwrite"
        ).parquet(output_url)

    print("write finished")


if __name__ == "__main__":
    spark = (
        SparkSession.builder.config("spark.driver.memory", "2g").master("local[2]").getOrCreate()
    )
    sc = spark.sparkContext

    # The schema defines how the dataset schema looks like
    HelloWorldSchema = Unischema(
        "HelloWorldSchema",
        [
            UnischemaField("id", np.int32, (), ScalarCodec(IntegerType()), False),
            UnischemaField("list_of_lists", np.uint8, (2, 2), NdarrayCodec(), False),
        ],
    )

    # Create resolver to path where we want to write to
    output_url = "{My URL here}"
    resolver = FilesystemResolver(output_url, sc._jsc.hadoopConfiguration(), hdfs_driver="libhdfs")
    fsystem = resolver.filesystem_factory()

    # Create and write out random dataset
    generate_hello_world_dataset(spark, HelloWorldSchema, output_url, fsystem)

    # Create PyTorch dataloader and ensure read works
    with DataLoader(
        make_reader(output_url, hdfs_driver="libhdfs"), batch_size=10, shuffling_queue_capacity=1000
    ) as train_loader:
        for i in train_loader:
            print(i["id"].shape)
            print(i["list_of_lists"].shape)
The materialize_dataset context manager is used to write some petastorm-specific metadata for the cases when you want to store tensors in a parquet file.
For implementing list-of-lists support, I would think you'd rather start with examples/hello_world/external_dataset/generate_external_dataset.py. It just shows a vanilla pyspark way of writing parquet files. These are the steps I would imagine one would have to take to fully implement list-of-lists support:

1. Modify external_dataset/generate_external_dataset.py so it writes out a parquet store with the list-of-lists type (a rough sketch of this step follows the list).
2. Use examples/hello_world/external_dataset/pytorch_hello_world.py to read the dataset with the new type and see it fail.
3. Modify Unischema.from_arrow_schema to account for the new field type. When creating UnischemaField entities in from_arrow_schema we can probably leave codec as None and work around it downstream when things start breaking.
4. Use the TransformSpec class to declare the list-of-lists to numpy-array logic.
5. Update create_test_scalar_dataset in test_common.py to include the new field type and make sure all tests pass.

This should be more or less it, but you would probably encounter things I overlooked...
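As a starting point for steps 1 and 4, here is a rough, untested sketch: it writes a vanilla parquet store whose column is declared as a nested ArrayType in Spark, and shows roughly how a TransformSpec might convert that column back into numpy arrays on read. The output path, column names, and the transform function are illustrative assumptions, not the final design, and the TransformSpec usage (including whether it receives rows or pandas frames) should be checked against the current petastorm docs:

import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType
# Import path may differ across petastorm versions; TransformSpec lives in petastorm.transform.
from petastorm.transform import TransformSpec

# Step 1 sketch: vanilla pyspark write where the list-of-lists column is a nested ArrayType.
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("list_of_lists", ArrayType(ArrayType(IntegerType())), False),
])

spark = SparkSession.builder.master("local[2]").getOrCreate()
rows = [(i, [[i, i + 1], [i + 2, i + 3]]) for i in range(10)]
spark.createDataFrame(rows, schema).write.mode("overwrite").parquet("/tmp/list_of_lists_parquet")

# Step 4 sketch: turn the nested-list column into a numpy array on the read side.
# Treat this as illustrative only; the exact hook depends on which reader is used.
def _to_ndarray(row):
    row["list_of_lists"] = np.asarray(row["list_of_lists"], dtype=np.int32)
    return row

transform = TransformSpec(_to_ndarray)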
To simplify debugging, you can pass reader_pool="dummy". Then the workers run on the main thread.
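For example, something along these lines (the dataset URL is a placeholder, and note that in recent petastorm releases the corresponding make_reader argument is named reader_pool_type, so double-check the version you are on):

from petastorm import make_reader

# 'dummy' runs everything on the main thread, so breakpoints and stack traces behave normally.
with make_reader("file:///tmp/hello_world_dataset", reader_pool_type="dummy") as reader:
    for row in reader:
        print(row)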
The check in https://github.com/uber/petastorm/blob/master/petastorm/unischema.py#L326 might not be relevant anymore, as pyarrow now supports more data types.
We should extend the tests to use the list-of-lists type and make sure everything works.
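A quick way to sanity-check that assumption is to confirm pyarrow can express the nested type directly, with no petastorm involved (column name and sample data below are just for illustration):

import pyarrow as pa

# Nested list types are first-class in modern pyarrow: a list-of-lists of int32.
nested_type = pa.list_(pa.list_(pa.int32()))
schema = pa.schema([pa.field("list_of_lists", nested_type)])

table = pa.table({"list_of_lists": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]}, schema=schema)
print(table.schema)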