pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Add a flatmap datapipe #178

Closed. erip closed this issue 2 years ago.

erip commented 2 years ago

🚀 The feature

Add a FlatMapDataPipe that simultaneously flattens nested IterDataPipes and applies a function to the nested pipes.
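A minimal, stdlib-only sketch of the proposed semantics: apply a function to each item of the source, then splice the resulting iterables into one flat stream. `FlatMap` is a hypothetical name used only for illustration. It is not torchdata's API, and a real DataPipe would subclass `IterDataPipe` instead of being a plain class.

```python
from typing import Callable, Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class FlatMap(Generic[T, U]):
    """Hypothetical sketch: apply fn to each source item, then flatten one level."""

    def __init__(self, source: Iterable[T], fn: Callable[[T], Iterable[U]]) -> None:
        self.source = source
        self.fn = fn

    def __iter__(self) -> Iterator[U]:
        for item in self.source:
            # fn returns an iterable (e.g. an inner pipe); splice its elements in order
            yield from self.fn(item)


# Nested "pipes" (plain lists here): double every innermost element
nested = [[1, 2], [3], []]
print(list(FlatMap(nested, lambda inner: (x * 2 for x in inner))))  # [2, 4, 6]
```

Because `__iter__` walks `source` again on every call, the object can be iterated more than once as long as `source` can.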

Motivation, pitch

I have tarballs containing tarballs. When dealing with this kind of recursive data, it's useful to have a mechanism that both flattens the structure and applies a function to it, so I only need to consider the innermost data.

Alternatives

Domain libraries can define their own impls.

Additional context

Migrating torchtext to datapipes via torchdata would benefit from this feature when dealing with IWSLT data.

erip commented 2 years ago

Maybe another good analog would be flatten, which does the same thing without the intermediate map (equivalently, flatmap with the identity function).
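A small stdlib-only sketch of that equivalence. `flatmap` and `flatten` here are hypothetical helpers built on `itertools.chain.from_iterable`, not torchdata functions, and both remove exactly one level of nesting.

```python
from itertools import chain
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def flatmap(fn: Callable[[T], Iterable[U]], items: Iterable[T]) -> Iterator[U]:
    # Apply fn to each item, then splice the resulting iterables together (one level)
    return chain.from_iterable(map(fn, items))


def flatten(items: Iterable[Iterable[U]]) -> Iterator[U]:
    # flatten is just flatmap with the identity function
    return flatmap(lambda inner: inner, items)


print(list(flatten([[1, 2], [3], [4, 5]])))  # [1, 2, 3, 4, 5]
```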

ejguan commented 2 years ago

We do have a plan to add flatmap, but the behavior may differ from the proposal. We expect flatmap to expand the list returned by map_fn:

>>> from torchdata.datapipes.iter import IterableWrapper
>>> def fn(data):
...     return [data, data*10]

>>> dp = IterableWrapper([0,1,2,3])
>>> dp = dp.flatmap(fn)
>>> list(dp)
[0, 0, 1, 10, 2, 20, 3, 30]

We do have an operator named unbatch, which behaves the same as flatten.

erip commented 2 years ago

Hmm, I think you may be right, and my implementation is really a concatmap. The difference is when the map happens (before for flatmap, after for concatmap). 🤔
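To make the ordering question concrete, here is a stdlib-only sketch contrasting "map, then flatten" with "flatten, then map". The function names are hypothetical, and naming conventions differ between libraries (the Haskell report, for example, defines `concatMap f = concat . map f`), so treat this as an illustration of the two orderings rather than a ruling on the terms.

```python
from itertools import chain
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def map_then_flatten(fn: Callable[[T], Iterable[U]], items: Iterable[T]) -> Iterator[U]:
    # fn sees each outer element (e.g. a whole inner tarball) and returns an iterable
    return chain.from_iterable(map(fn, items))


def flatten_then_map(fn: Callable[[T], U], items: Iterable[Iterable[T]]) -> Iterator[U]:
    # fn sees each innermost element after one level of nesting is removed
    return map(fn, chain.from_iterable(items))


nested: List[List[int]] = [[1, 2], [3]]
print(list(map_then_flatten(lambda inner: [sum(inner)], nested)))  # [3, 3]
print(list(flatten_then_map(lambda x: x * 10, nested)))  # [10, 20, 30]
```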

ejguan commented 2 years ago

Technically speaking, this flatmap can also be applied to your use case: the function becomes a generator yielding data.
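A small stdlib-only sketch of that point: since flatmap only needs `fn` to return an iterable, a generator function works as well as a list, and nothing inside it runs until the output is consumed. `flatmap` and `expand` are hypothetical helpers, not torchdata's implementation.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def flatmap(fn: Callable[[T], Iterable[U]], items: Iterable[T]) -> Iterator[U]:
    # Lazily splice each iterable returned by fn into a single stream
    for item in items:
        yield from fn(item)


calls: List[int] = []


def expand(data: int) -> Iterator[int]:
    # A generator fn: records when it actually runs, then yields two values
    calls.append(data)
    yield data
    yield data * 10


stream = flatmap(expand, [0, 1, 2, 3])
print(calls)  # [] (nothing has run yet)
print(list(stream))  # [0, 0, 1, 10, 2, 20, 3, 30]
print(calls)  # [0, 1, 2, 3]
```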

erip commented 2 years ago

Yes, the pseudocode I have in mind is basically:

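from torchdata.datapipes.iter import FileOpener, IterableWrapper

# [ 1.tgz, 2.tgz, ..., n.tgz]; "outer.tar" is a placeholder path
outer_tar = FileOpener(IterableWrapper(["outer.tar"]), mode="b")
inner_tars = outer_tar.read_from_tar()

# [ 1_1.txt, 1_2.txt, 2_1.txt, 2_2.txt, 3_1.txt, ...]. Need a better way than a lambda
# Each inner_tar is a (path, stream) tuple, so wrap it instead of passing it to FileOpener
inner_inner_files = inner_tars.flatmap(lambda inner_tar: IterableWrapper([inner_tar]).read_from_tar())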

So the flatmap'd read_from_tar becomes the generator yielding data, and flatmap flattens it into an IterDataPipe. I think the test I've written doesn't actually reflect this use case 😬, but that's a thought for the PR.
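For readers without torchdata at hand, the nested-tarball use case can be simulated with the standard library alone. This sketch builds a tar of `.tgz` files in memory (the file names and contents are made up for illustration) and uses a generator-based `flatmap` so that only the innermost files come out. `read_from_tar` here is a hypothetical stdlib helper that loosely mirrors the (path, stream) pairs torchdata's `read_from_tar` yields, but it returns bytes instead of streams.

```python
import io
import tarfile
from typing import Callable, Dict, Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")
Entry = Tuple[str, bytes]


def make_tar(files: Dict[str, bytes], mode: str) -> bytes:
    # Build a tar archive in memory; mode is "w" (plain) or "w:gz" (gzipped)
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode=mode) as tar:
        for name, data in files.items():
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()


def read_from_tar(entry: Entry) -> Iterator[Entry]:
    # Yield (member name, bytes) for every regular file in the archive held by entry
    _, blob = entry
    with tarfile.open(fileobj=io.BytesIO(blob), mode="r:*") as tar:
        for member in tar:
            if member.isfile():
                fobj = tar.extractfile(member)
                assert fobj is not None
                yield member.name, fobj.read()


def flatmap(fn: Callable[[T], Iterable[U]], items: Iterable[T]) -> Iterator[U]:
    # Apply fn to each item and splice the results into one lazy stream
    for item in items:
        yield from fn(item)


inner_1 = make_tar({"1_1.txt": b"a", "1_2.txt": b"b"}, "w:gz")
inner_2 = make_tar({"2_1.txt": b"c"}, "w:gz")
outer = make_tar({"1.tgz": inner_1, "2.tgz": inner_2}, "w")

inner_tars = read_from_tar(("outer.tar", outer))
inner_inner_files = flatmap(read_from_tar, inner_tars)
print([name for name, _ in inner_inner_files])  # ['1_1.txt', '1_2.txt', '2_1.txt']
```

Using a named function (`read_from_tar`) rather than a lambda also sidesteps the "need a better way than a lambda" concern in this sketch, though whether that matters for real DataPipes depends on how they are serialized.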