
MPRS: Keeping references to datapipes in wrappers can cause sequentialization of pipes #1112

Open sehoffmann opened 1 year ago

sehoffmann commented 1 year ago

🐛 Describe the bug

This took me an extremely long time to figure out, and it is a super sneaky bug:

import os
import functools
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.datapipes import functional_datapipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

# Pass-through pipe that marks itself as non-replicable, so that it and
# everything after it run in a single (non-worker) process.
@functional_datapipe('non_replicable')
class NonReplicableIterDataPipe(IterDataPipe):
    def __init__(self, dp):
        self.dp = dp

    def __iter__(self):
        return iter(self.dp)

    def is_replicable(self):
        return False

def init_fn(datapipe, worker_info):
    print('Hello from init_fn: ', os.getpid(), flush=True)
    return datapipe

def test(x, prefix=''):
    print(prefix, ' ', x, ' ', os.getpid(), flush=True)
    return x

def create_pipe(src):
    pipe = src
    pipe = pipe.map(functools.partial(test, prefix='Pre Main'))  # main process
    pipe = pipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)  # dispatch to workers
    pipe = pipe.map(functools.partial(test, prefix='Worker'))  # worker processes
    pipe = pipe.non_replicable()  # back to the main process
    pipe = pipe.map(functools.partial(test, prefix='Post Main'))  # main process
    return pipe

class Wrapper(IterDataPipe):

    def __init__(self):
        # Keeping this extra reference to src is what triggers the bug.
        self.src = IterableWrapper(range(4))
        self.dp = create_pipe(self.src)

    def __iter__(self):
        return iter(self.dp)

def main():
    print('Main Process: ', os.getpid(), flush=True)
    pipe = create_pipe(IterableWrapper(range(4)))

    rs = MultiProcessingReadingService(2, worker_init_fn=init_fn, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    loader = DataLoader2(pipe, reading_service=rs)
    [_ for _ in loader]

    print('-------------------------------------')

    rs = MultiProcessingReadingService(2, worker_init_fn=init_fn, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    loader = DataLoader2(Wrapper(), reading_service=rs)
    [_ for _ in loader]

if __name__ == '__main__':
    main()

This pipe is supposed to do the following:

  1. Preprocess data in the main process, up to sharding_round_robin_dispatch
  2. Process data in the worker processes
  3. Postprocess data in the main process, from non_replicable() onwards

Output:

Main Process:  22744
Hello from init_fn:  22809
Hello from init_fn:  22746
Pre Main   0   22745
Pre Main   1   22745
Worker   0   22746
Worker   1   22809
Pre Main   2   22745
Pre Main   3   22745
Post Main   0   22744
Worker   2   22746
Post Main   1   22744
Worker   3   22809
Post Main   2   22744
Post Main   3   22744
-------------------------------------
Hello from init_fn:  22944
Hello from init_fn:  22968
Pre Main   0   22943
Worker   0   22943
Post Main   0   22943
Pre Main   1   22943
Worker   1   22943
Post Main   1   22943
Pre Main   2   22943
Worker   2   22943
Post Main   2   22943
Pre Main   3   22943
Worker   3   22943
Post Main   3   22943

As you can see, the first case works as expected, but in the second case the pipeline is completely sequentialized. In fact, it runs entirely in the dispatching process of the MPRS.

Now, after having finally figured out the source of this bug, I believe I understand why it behaves the way it does. Keeping the reference to src causes the whole graph and pipeline to become non-replicable.

I'm aware of the technical limitations, i.e. that torchdata uses pickle and object attributes to figure out the graph, and does not in fact take into consideration which part of it is actually used in __iter__. However, at least from my perspective, this is a very severe issue. This behavior is extremely unexpected, and it's very easy to cause it unconsciously. As you can tell from my description so far, it also took me considerable time to even figure out its source. Moreover, it's hard to even recognize in the first place; in my specific case, I only noticed it because a serialization operation failed due to a tensor being on the GPU (an operation I do at the very end of my pipeline). Had I not activated and tested that particular part of the pipeline, this issue would likely have eluded me completely until some benchmark revealed vastly reduced performance.
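
To see what torchdata actually considers part of the graph, the traversal can be inspected directly. A minimal sketch, assuming the repro classes above are in scope and that traverse_dps is importable from torchdata.dataloader2.graph (the traversal walks pickled attributes, which is exactly why the extra self.src reference gets picked up):

from torchdata.dataloader2.graph import traverse_dps

# The resulting dict maps each pipe to its parents. For Wrapper it should
# list two direct parents: the end of the `dp` chain AND `src` itself,
# even though only `dp` is ever consumed in __iter__.
graph = traverse_dps(Wrapper())
print(graph)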

To make it short: I believe this definitely needs addressing in one form or another. At the very least, a very big disclaimer has to be put into the docs. Of course, it would be even better if this issue did not occur in the first place, without any intervention from the user. Should this not be possible, we definitely need a wrapper class akin to a "WeakReference" that allows us to mark an IterDataPipe attribute of an object as "not part of the graph", i.e. something like:

class Wrapper(IterDataPipe):

    def __init__(self):
        self.src = NoGraph(IterableWrapper(range(4)))  # NoGraph: the proposed wrapper; does not exist yet
        self.dp = create_pipe(self.src)

    def __iter__(self):
        return iter(self.dp)
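
For comparison, dropping the stored attribute avoids the problem entirely, since the traversal then only sees the dp chain; this variant (name chosen here purely for illustration) behaves like the first, parallel case:

class WrapperWithoutRef(IterDataPipe):

    def __init__(self):
        src = IterableWrapper(range(4))  # local variable only, not stored on self
        self.dp = create_pipe(src)

    def __iter__(self):
        return iter(self.dp)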

What I really find concerning is that this can easily go unnoticed. Because of this, I would even suggest considering a get_parents() function that must explicitly list which pipes are parents. The behavior implemented by torchdata could then look like this:

  1. In order to still support the most common case: if there is only one reference to another DataPipe, use it as the parent by default, i.e. the same behavior that is currently in place.
  2. If there is more than one reference and a get_parents() function exists, use the output of get_parents() (see the sketch after this list).
  3. If there is more than one reference but no get_parents() function, throw an error.
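
To illustrate point 2, a pipe opting into the proposed API could look like the sketch below; get_parents() is hypothetical and does not exist in torchdata today:

class BookkeepingPipe(IterDataPipe):

    def __init__(self, dp, reference_pipe):
        self.dp = dp                          # the actual input, consumed in __iter__
        self.reference_pipe = reference_pipe  # kept around for bookkeeping only

    def get_parents(self):
        # Proposed API: only self.dp is a graph parent; the traversal
        # would ignore self.reference_pipe despite the attribute.
        return [self.dp]

    def __iter__(self):
        return iter(self.dp)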

Versions

https://github.com/pytorch/data/commit/e78ab6c9ec94f05f0a350ced7fe571f6863c20ec

ejguan commented 1 year ago

Here are a few things that come to mind to help users easily find this problem: