pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

MapDatapipe Mux/Demux Support #310

Open josiahls opened 2 years ago

josiahls commented 2 years ago

🚀 The feature

MapDataPipes are missing Mux and Demux pipes, as noted in https://github.com/pytorch/pytorch/issues/57031

I talked to @ejguan on https://discuss.pytorch.org/t/mapdatapipe-support-mux-demux/146305 and plan to open a PR adding Mux/Demux. First, though, I will add rough outlines/ideas here. I plan to follow the same test strategy as the existing Mux/Demux IterDataPipes.

Motivation, pitch

For Demux: my basic test/goal is to download MNIST and split it into train/validation sets using MapDataPipes. For Mux: then attempt to mux them back together (I'm not sure how to come up with a useful example of this).

I'm not sure when this should be converted to a PR. This would be my first PR into PyTorch, so I want it to be as clean as possible. Sketching the code changes here lets me make more dramatic/messy changes without producing a messy git diff, and I can worry about formatting once the code is finalized.

Note: docstrings are removed to keep the code short and will be re-added in the PR. Not-super-useful comments will be removed in the PR.

Note: let me know if a draft PR would be better.

Demux working code: Draft 1: https://github.com/josiahls/fastrl/blob/848f90d0ed5b0c2cd0dd3e134b0b922dd8a53d7c/fastrl/fastai/data/pipes.py

Demux working code + Basic Test Draft 1: https://github.com/josiahls/fastrl/blob/848f90d0ed5b0c2cd0dd3e134b0b922dd8a53d7c/nbs/02c_fastai.data.pipes.ipynb

Mux working code: Draft 1: https://github.com/josiahls/fastrl/blob/30cd47766e9fb1bc75d32de877f54b8de9567c36/fastrl/fastai/data/pipes/mux.py

Basic Test Draft 1: https://github.com/josiahls/fastrl/blob/30cd47766e9fb1bc75d32de877f54b8de9567c36/nbs/02c_fastai.data.pipes.mux.ipynb

josiahls commented 2 years ago

There are a couple of assumptions that may differ between the Iter and Map versions of Demux/Mux. Please let me know of any caveats/assumptions you would expect from Mux/Demux. I'm also curious what the default behavior is supposed to be; I notice that lazy loading seems to be the default.

Assumption 1: What is an instance_id? _ChildDataPipe could allow the instance id to be a non-int, as long as it is hashable. Allowing any hashable id would let you split the datapipes and know which one is train and which one is valid based on the ids. I could be missing something, but I'm assuming the int requirement is relevant only to IterDataPipes.

class _ChildDataPipe(dp.map.MapDataPipe):
#    def __init__(self, main_datapipe, instance_id: int):
    def __init__(self, main_datapipe, instance_id: Hashable):
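To make Assumption 1 concrete, here is a minimal, framework-free sketch (plain Python, no torchdata) in which children are keyed by any hashable id such as "train"/"valid". `_KeyedChild` and `split_by_key` are hypothetical names, and classification is done eagerly for brevity; a real MapDataPipe version would subclass `MapDataPipe` instead.

```python
from typing import Callable, Dict, Hashable, List, Sequence


class _KeyedChild:
    """Hypothetical map-style child holding only the items whose label equals `instance_id`."""

    def __init__(self, source: Sequence, source_indices: List[int], instance_id: Hashable):
        self.source = source
        self.source_indices = source_indices  # positions in `source` owned by this child
        self.instance_id = instance_id

    def __getitem__(self, index: int):
        return self.source[self.source_indices[index]]

    def __len__(self) -> int:
        return len(self.source_indices)


def split_by_key(source: Sequence, classifier_fn: Callable[[object], Hashable]) -> Dict[Hashable, _KeyedChild]:
    """Classify every element up front and return one child per hashable label."""
    buckets: Dict[Hashable, List[int]] = {}
    for i in range(len(source)):
        buckets.setdefault(classifier_fn(source[i]), []).append(i)
    return {key: _KeyedChild(source, idxs, key) for key, idxs in buckets.items()}


# Toy stand-in for MNIST: every 5th element goes to validation.
data = list(range(10))
children = split_by_key(data, lambda x: "valid" if x % 5 == 0 else "train")
train, valid = children["train"], children["valid"]
```

With string ids, the caller never has to remember that "instance 1" meant validation.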

Assumption 2: Should num_instances be kept? Since the len of a Map is known (and "all the data is there already"), num_instances is no longer needed in Demux. We could still keep this param, but only for hinting/validating that the split occurred as expected. It could also be replaced by instance_keys, an Optional list of the keys we expect. If instance_keys is None, we lazily determine the number of instances. If it is not None, we raise an error on any key we did not expect.

One issue with the above: if a Demux can output 1->N instances, how would the child datapipes be constructed to adjust to this? Maybe if the Demux is set to build the map on init (before iteration), then we don't have to set num_instances, but if it is expected to load lazily, then num_instances is required. Maybe this is too complex and num_instances should be required 100% of the time.
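The `instance_keys` idea could look roughly like the following sketch. `resolve_instance_keys` is a hypothetical helper, not part of torchdata. Note that both branches read the whole source, which is exactly the cost a lazy Demux wants to avoid; that trade-off is one argument for simply requiring num_instances (or the keys) up front.

```python
from typing import Callable, Hashable, Iterable, List, Optional, Sequence


def resolve_instance_keys(
    source: Sequence,
    classifier_fn: Callable[[object], Hashable],
    instance_keys: Optional[Iterable[Hashable]] = None,
) -> List[Hashable]:
    """Hypothetical helper for the `instance_keys` proposal.

    - instance_keys is None: scan `source` eagerly and return keys in order of first appearance.
    - instance_keys is given: return them unchanged, but raise if any element maps to an
      unexpected key.
    """
    expected = None if instance_keys is None else list(instance_keys)
    seen: List[Hashable] = []
    for i in range(len(source)):
        key = classifier_fn(source[i])
        if expected is not None and key not in expected:
            raise ValueError(f"Unexpected instance key {key!r} for element at index {i}")
        if key not in seen:
            seen.append(key)
    return seen if expected is None else expected
```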

Assumption 3: What role do buffers play in all this? Do buffers have any place in Map datapipes?


class _DemultiplexerMapDataPipe(dp.map.MapDataPipe):
...
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `demux`, "
                "please be aware of OOM at random places",
                UserWarning
            )
        self.current_buffer_usage = 0

I could argue that a buffer could act as a cache for a map. For example, when reading from a file system, the map could be set to keep N loaded images in memory. The buffer would default to zero so that memory use stays lean.
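As a sketch of that caching idea (the `CachedMap` class and its `loader` callable are hypothetical, not a torchdata API), a map-style wrapper could keep at most `buffer_size` loaded items with least-recently-used eviction, and do no caching at the default size of 0:

```python
from collections import OrderedDict
from typing import Callable, Hashable, Iterable


class CachedMap:
    """Hypothetical map-style wrapper keeping at most `buffer_size` loaded items in memory.

    buffer_size == 0 disables caching (lean default); items are evicted least-recently-used first.
    """

    def __init__(self, loader: Callable[[Hashable], object], keys: Iterable[Hashable], buffer_size: int = 0):
        if buffer_size < 0:
            raise ValueError("buffer_size must be >= 0")
        self.loader = loader  # e.g. reads an image from disk
        self.keys = list(keys)
        self.buffer_size = buffer_size
        self._cache: "OrderedDict[Hashable, object]" = OrderedDict()
        self.load_count = 0  # how many times `loader` actually ran

    def __getitem__(self, key):
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]
        value = self.loader(key)
        self.load_count += 1
        if self.buffer_size > 0:
            self._cache[key] = value
            if len(self._cache) > self.buffer_size:
                self._cache.popitem(last=False)  # evict the least recently used item
        return value

    def __len__(self) -> int:
        return len(self.keys)
```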

Assumption 4: Does exhaustion make sense anymore? main_datapipe_exhausted and any other reference to datapipe exhaustion no longer make sense in this context. So what happens when we get to the end of a map? It might be better to have each child store its keys/indexes during the first iteration; once it reaches the end, we stop storing keys and rely on the cached ones.

Assumption 5: How are indexes known? To split into child pipes, we might want to cache the indexes for each respective child pipe. To do this, we need to know how those indexes are stored in the main datapipe so that they can be retrieved.

@ejguan Am I missing something obvious, or is there really no standard way of getting the indexes for a given MapDataPipe? Given a MapDataPipe, I am always tempted to call a keys() method that doesn't exist. Currently Demux needs to do this:

    def _setup_datapipe_indexer(self) -> Optional[Iterable[Any]]:
        # self._datapipe_iterator: Optional[Iterator[Any]] = None
        # Instead of _datapipe_iterator we have _datapipe_indexer
        # We need to know how to get the index from the main_datapipe. In order
        # to do this, we check if it is...

        # NOTE: THIS IS NOT A GOOD SOLUTION SINCE THIS CANT RELY ON A STANDARD
        # INTERFACE FOR GETTING INDEXES

        # We cache the indexes because we want consistent behavior
        # when calling __getitem__ on a child pipe.
        # What we don't want is the main_datapipe being indexed by `str` but the
        # child pipes indexing by `int`...
        if isinstance(self.main_datapipe, dp.map.SequenceWrapper):
            return range(len(self.main_datapipe))
        elif hasattr(self.main_datapipe, '_map'):
            return iter(self.main_datapipe._map)
        elif hasattr(self.main_datapipe, 'index_map'):
            return iter(self.main_datapipe.index_map)
        else:
            warnings.warn('data pipe will be indexed by len')
            return range(len(self.main_datapipe))
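A standard interface would remove the isinstance/hasattr chain above. As a hypothetical sketch (the `indices()` method name and the `HasIndices` protocol are assumptions, not an existing torchdata API), a pipe could opt in by exposing its keys, with 0-based indexing as the fallback:

```python
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class HasIndices(Protocol):
    """Hypothetical opt-in interface: a map-style pipe that can list its own keys."""

    def indices(self) -> Iterable: ...


def get_indices(datapipe) -> Iterable:
    """Return the keys of a map-style pipe without per-class special cases."""
    if isinstance(datapipe, HasIndices):
        return datapipe.indices()
    return range(len(datapipe))  # fall back to 0-based indexing


class DictMap:
    """Toy map-style pipe indexed by str keys that opts into `indices()`."""

    def __init__(self, d: dict):
        self.d = d

    def __getitem__(self, key):
        return self.d[key]

    def __len__(self) -> int:
        return len(self.d)

    def indices(self) -> list:
        return list(self.d)
```

Demux would then call `get_indices(main_datapipe)` once and never inspect private attributes such as `_map` or `index_map`.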

Assumption 6: Maintaining index integrity. This is going to be an issue for both Mux and Demux. Currently, Demux maintains index integrity: if I have 1 pipe with indexes [0-20], Demux will split it into pipes with indexes dp1: [0-10] and dp2: [10-20].

In this case, trying dp2[0] will raise an IndexError.
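The two behaviors can be contrasted in a small plain-Python sketch; `ForwardingChild` and `ZeroBasedChild` are hypothetical names. The first keeps the parent's keys, so dp2[0] fails; the second renumbers from 0, so dp2[0] is the parent's item 10.

```python
from typing import Hashable, Iterable, Sequence


class ForwardingChild:
    """Hypothetical child that keeps the parent's keys (index integrity preserved)."""

    def __init__(self, source: Sequence, keys: Iterable[Hashable]):
        self.source = source
        self._keys = list(keys)
        self._key_set = set(self._keys)

    def __getitem__(self, key):
        if key not in self._key_set:
            raise IndexError(f"{key!r} does not belong to this child")
        return self.source[key]

    def __len__(self) -> int:
        return len(self._keys)


class ZeroBasedChild:
    """Hypothetical child that renumbers its items from 0 (the parent's keys are hidden)."""

    def __init__(self, source: Sequence, keys: Iterable[Hashable]):
        self.source = source
        self._keys = list(keys)

    def __getitem__(self, index: int):
        # Position `index` maps to the parent key stored at that position;
        # an out-of-range position raises IndexError from the list lookup.
        return self.source[self._keys[index]]

    def __len__(self) -> int:
        return len(self._keys)


source = [f"item{i}" for i in range(20)]
dp2_forwarded = ForwardingChild(source, range(10, 20))
dp2_zero_based = ZeroBasedChild(source, range(10, 20))
```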

Assumption 7: ChildPipes need to forward the index strategy. Draft 1 currently behaves strangely when a child pipe is passed into another Demux: the index strategy is lost. So if the original pipe uses a str-indexed dict, ChildPipe -> Demux will fall back to range(len(...)) indexing. This is undesirable. See Assumption 5.
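One way to forward the index strategy, sketched in plain Python under the assumption that pipes expose a hypothetical `indices()` method: each child reports the subset of its parent's keys, so a nested demux keeps str keys instead of falling back to `range(len(...))`. `KeyedChild` and `demux_keys` are illustrative names only.

```python
from typing import Callable, Hashable, List


class KeyedChild:
    """Hypothetical child that forwards its parent's keys via `indices()`."""

    def __init__(self, parent, keys: List[Hashable]):
        self.parent = parent
        self._keys = list(keys)

    def indices(self) -> List[Hashable]:
        return list(self._keys)

    def __getitem__(self, key):
        if key not in self._keys:
            raise KeyError(key)
        return self.parent[key]

    def __len__(self) -> int:
        return len(self._keys)


def demux_keys(parent, num_instances: int, classifier_fn: Callable[[object], int]) -> List[KeyedChild]:
    # Use the parent's own keys when it exposes them; otherwise assume 0-based indexing.
    keys = parent.indices() if hasattr(parent, "indices") else range(len(parent))
    buckets: List[List[Hashable]] = [[] for _ in range(num_instances)]
    for k in keys:
        buckets[classifier_fn(parent[k])].append(k)
    return [KeyedChild(parent, b) for b in buckets]


data = {"a": 1, "b": 2, "c": 3, "d": 4}
root = KeyedChild(data, list(data))  # wrap the str-keyed dict so it exposes indices()
odd, even = demux_keys(root, 2, lambda v: 0 if v % 2 else 1)
small, big = demux_keys(odd, 2, lambda v: 0 if v < 2 else 1)  # nested demux keeps str keys
```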

ejguan commented 2 years ago

Thanks for all the detail. cc @NivekT, the original author of the Iter Demux. I just skimmed through the assumptions. As for buffer and exhaustion, I don't think a Map-style DataPipe needs such variables, since the data should be indexable. The reason we have them is that IterDataPipe is streaming, which requires us to cache the output from the prior DataPipe. I will take a deeper look at your proposal.

NivekT commented 2 years ago

I need more time to think about demux but here is a version of mux that I quickly put together:

from itertools import chain, zip_longest

from typing import Dict, Iterable, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe

@functional_datapipe("mux")
class MultiplexerMapDataPipe(MapDataPipe):
    def __init__(self, *datapipes, dp_index_map: Optional[Dict[MapDataPipe, Iterable]] = None):
        self.datapipes = datapipes
        self.dp_index_map = dp_index_map if dp_index_map else {}
        self.length: Optional[int] = None
        self.index_map = {}
        # Create a generator that yields (index, (dp_num, old_index)) in sequential order.
        indices = (self._add_dp_num(i, dp) for i, dp in enumerate(datapipes))
        dp_id_and_key_tuples = chain.from_iterable(zip_longest(*indices))
        self.key_gen = enumerate(e for e in dp_id_and_key_tuples if e is not None)

    def _add_dp_num(self, dp_num: int, dp: MapDataPipe):
        # Assume 0-index for all DataPipes unless alternate indices are defined in `self.dp_index_map`
        dp_indices = self.dp_index_map[dp] if dp in self.dp_index_map else range(len(dp))
        for idx in dp_indices:
            yield dp_num, idx

    def __getitem__(self, index):
        if 0 <= index < len(self):
            # Advance the key generator until `index` is reached, caching every
            # (dp_num, old_key) pair seen on the way so that earlier indices stay
            # valid even if a later index was requested first.
            while index not in self.index_map:
                curr_key, dp_num_key_tuple = next(self.key_gen)
                self.index_map[curr_key] = dp_num_key_tuple
            dp_num, old_key = self.index_map[index]
            try:
                return self.datapipes[dp_num][old_key]
            except (KeyError, IndexError):
                raise RuntimeError(
                    f"Incorrect key is given to MapDataPipe {dp_num} in Multiplexer, likely because "
                    f"that DataPipe is not 0-index but alternate indices are not given."
                )
        raise RuntimeError(f"Index {index} is out of bound for Multiplexer.")

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        if self.length is None:
            self.length = 0
            for dp in self.datapipes:
                self.length += len(dp)
        return self.length

It can be used as such:

from torch.utils.data.datapipes.map import SequenceWrapper

a = SequenceWrapper(range(10))
b = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
dp = a.mux(b, dp_index_map={b: ['a', 'b', 'c', 'd']})
list(dp)  # Returns [0, 100, 1, 200, 2, 300, 3, 400, 4, 5, 6, 7, 8, 9]

Nothing is final, and I'm happy to discuss what the APIs should be. I think the key difference from your version is that this one assumes the input DataPipes are 0-indexed unless alternate indices are specified in the input (so it doesn't need isinstance or hasattr checks). Aside from that, I think they are mostly similar.

NivekT commented 2 years ago
  1. Currently, we require the classifier_fn to return a value in [0, ..., num_instances - 1], which is basically the instance_id. Every element is classified and assigned to the DataPipe of that instance_id. If we replace instance_id, we will need an additional argument indicating the possible classifications and what their ordering should be (which would determine the ordering of the return value: List[MapDataPipe]).

  2. I think knowing num_instances up front is better than trying to figure it out dynamically, especially since the DataPipes are lazy.

  3. I don't think we need buffer for MapDataPipe.

  4. "Exhaustion" probably means something different here. It should mean the source DataPipe has been fully read, so we know which child DataPipe each element belongs to.

  5. I am starting to think it might be desirable to have a standard way to retrieve indices from a MapDataPipe (maybe self.indices). It depends on how often people deviate from 0-based indexing. At a minimum, we should probably add a DataPipe or function that lets people attach a custom index to a 0-indexed MapDataPipe. Without that, the only way is to accept indices as an argument; see the mux example above.

  6. One option is to reset each child's indexing so that all of them start from 0. This should be fine for most cases? Especially when the original DataPipe is 0-indexed or the user doesn't know in advance how many elements will be assigned to each child DataPipe.

  7. Perhaps we can allow an option to forward the index strategy in demux, but we may want a uniform way to access the indices of all DataPipes first.
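For point 5, a DataPipe that attaches a custom index to a 0-indexed MapDataPipe might look like this plain-Python sketch. `WithIndex` and its `indices()` method are hypothetical, not an existing torchdata API; the wrapper validates that there is exactly one unique key per element.

```python
from typing import Hashable, Iterable, List


class WithIndex:
    """Hypothetical wrapper: attach custom keys to a 0-indexed map-style pipe.

    keys[i] becomes the public key of the item stored at position i of `datapipe`.
    """

    def __init__(self, datapipe, keys: Iterable[Hashable]):
        self.datapipe = datapipe
        self._keys: List[Hashable] = list(keys)
        if len(self._keys) != len(datapipe):
            raise ValueError("Expected exactly one key per element")
        self._pos = {k: i for i, k in enumerate(self._keys)}  # key -> 0-based position
        if len(self._pos) != len(self._keys):
            raise ValueError("Keys must be unique")

    def indices(self) -> List[Hashable]:
        return list(self._keys)

    def __getitem__(self, key):
        return self.datapipe[self._pos[key]]

    def __len__(self) -> int:
        return len(self._keys)
```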

Here is a buffer-free implementation in which the indices of all children become 0-based again. We could potentially expand on this by allowing index forwarding, but to do that, we might want to provide a standard way of retrieving indices from all (or at least some) MapDataPipes first?

from typing import Callable, Dict, Iterable, Optional, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe

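from torch.utils.data.datapipes.utils.common import check_lambda_fn

T_co = TypeVar("T_co", covariant=True)

@functional_datapipe("demux")
class DemultiplexerMapDataPipe(MapDataPipe):  # must subclass MapDataPipe for `.demux` to be registered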
    def __new__(cls, datapipe: MapDataPipe, num_instances: int, classifier_fn: Callable, drop_none: bool = False,
                source_index: Optional[Iterable] = None):
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
        check_lambda_fn(classifier_fn)
        container = _DemultiplexerMapDataPipe(datapipe, num_instances, classifier_fn, drop_none, source_index)
        return [_DemultiplexerChildMapDataPipe(container, i) for i in range(num_instances)]

class _DemultiplexerMapDataPipe:
    def __init__(
        self,
        datapipe: MapDataPipe[T_co],
        num_instances: int,
        classifier_fn: Callable[[T_co], Optional[int]],
        drop_none: bool,
        source_index: Optional[Iterable],
    ):
        self.main_datapipe = datapipe
        self.num_instances = num_instances
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.iterator = None
        self.exhausted = False  # Once we iterate through `main_datapipe` once, we know all the index mapping
        self.index_mapping = [[] for _ in range(num_instances)]
        self.source_index = source_index  # if None, assume `main_datapipe` 0-index

    def _classify_next(self):
        if self.source_index is None:
            self.source_index = range(len(self.main_datapipe))
        if self.iterator is None:
            self.iterator = iter(self.source_index)
        try:
            next_source_idx = next(self.iterator)
        except StopIteration:
            self.exhausted = True
            return
        value = self.main_datapipe[next_source_idx]
        classification = self.classifier_fn(value)
        if classification is None and self.drop_none:
            return  # skip this element; callers keep calling until they have enough
        if classification is None or not 0 <= classification < self.num_instances:
            raise ValueError(
                f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
                f"{classification} is returned."
            )
        # Store the source index rather than the value, so no data is buffered in memory.
        self.index_mapping[classification].append(next_source_idx)

    def classify_all(self):
        while not self.exhausted:
            self._classify_next()

    def get_value(self, instance_id: int, index: int) -> T_co:
        while not self.exhausted and len(self.index_mapping[instance_id]) <= index:
            self._classify_next()
        if len(self.index_mapping[instance_id]) > index:
            source_idx = self.index_mapping[instance_id][index]
            return self.main_datapipe[source_idx]
        raise RuntimeError("Index is out of bound.")

    def __len__(self):
        return len(self.main_datapipe)

class _DemultiplexerChildMapDataPipe(MapDataPipe):
    def __init__(self, main_datapipe: _DemultiplexerMapDataPipe, instance_id: int):
        self.main_datapipe: _DemultiplexerMapDataPipe = main_datapipe
        self.instance_id = instance_id

    def __getitem__(self, index: int):
        return self.main_datapipe.get_value(self.instance_id, index)

    def __len__(self):
        self.main_datapipe.classify_all()  # You have to read through the entirety of main_datapipe to know `len`
        return len(self.main_datapipe.index_mapping[self.instance_id])

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

VitalyFedyunin commented 2 years ago

Any reason to keep the datapipes as MapDataPipes? You can always convert them into IterDataPipes, use the standard demux, and avoid all the issues of keeping results in memory.

The main reason why we decided to limit MapDataPipes functionality is to encourage users to convert to IterDataPipes as fast as possible.

josiahls commented 2 years ago

@VitalyFedyunin As an outside user, I would expect basic torchdata to let me build a simple datapipe for something popular like MNIST, which is going to be Map-based. I would expect a simple map-based dataset to require only MapDataPipes. Having to switch to and from IterDataPipes to train on a "simple" map-indexed dataset doesn't seem intuitive to me, even if the Iter mux/demux would be viable. Maybe that's just me, though.

If you like, I'll try making a simple MNIST datapipe using what you recommend, and then make an example using only MapDataPipes so we can see how they compare.

josiahls commented 2 years ago

@NivekT I've been busy this week, but I should be able to spend this weekend on this. I do like the implementations you posted; they look way cleaner than mine lol. This weekend I'll try to poke some holes in your implementation with some simpler unit tests (not the complex MNIST setup I've been using).

"I am starting to think it might be desirable to have a standard way to retrieve indices from a MapDataPipe (maybe self.indices)"

That's what I was thinking. I kept wanting to call .keys() or something like that. I can make some attempts at a common index interface, whether it is a field, property, or function.

I'll edit this with a better opinion on what you posted early this weekend.