pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Accessing DataPipe state with MultiProcessingReadingService #1033

Open jhoareau opened 1 year ago

jhoareau commented 1 year ago

Hi TorchData team,

I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When using no reading service, we can simply access the graph using dataloader.datapipe, then I can easily access the state of my datapipe using the code shown below.

However, in the multiprocessing case, the datapipe graph is replaced with QueueWrapper instances, and I cannot find any way to communicate with the workers to access the state of the datapipe (I get the error that my StatefulIterator cannot be found on the datapipe). If I access dl2._datapipe_before_reading_service_adapt, I only get the initial state, which makes sense since there is no state sync between the main and worker processes.

As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.

Could we add a getstate communication primitive in communication.messages to capture the state (via getstate) of a datapipe in a worker process? We're also open to using sharding_round_robin_dispatch to keep more information in the main process, but I'm a bit confused about how to use it. Do you have some sample code for the following case?

Running against today's master (commit a3b34a00e7d2b6694ea0d5e21fcc084080a3abae):

import torchdata.datapipes as dp
from torch.utils.data.graph_settings import get_all_graph_pipes, traverse_dps
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

class StatefulIterator(dp.iter.IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe
        self.custom_index = 0

    def __iter__(self):
        self.custom_index = 0
        for item in self.datapipe:
            self.custom_index += 1
            yield item
        self.custom_index = 0

def get_datapipe():
    initial_data = dp.iter.IterableWrapper([1, 2, 3, 4])
    stateful_data = StatefulIterator(initial_data)
    sharded_data = stateful_data.sharding_filter()
    return sharded_data

def get_datapipe_state(datapipe):
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, "custom_index"):
            return pipe.custom_index

    raise ValueError("This datapipe does not contain a StatefulIterator.")

def main_no_multiprocessing():
    # Local name chosen to avoid shadowing the `dp` module alias imported above.
    datapipe = get_datapipe()
    dl2 = DataLoader2(datapipe)
    for item in dl2:
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)

def main_multiprocessing():
    datapipe = get_datapipe()
    dl2 = DataLoader2(datapipe, reading_service=MultiProcessingReadingService(num_workers=4))
    for item in dl2:
        # Fails here: dl2.datapipe is the adapted graph, so no StatefulIterator is found.
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)

if __name__ == "__main__":
    main_no_multiprocessing()
    main_multiprocessing()

cc: @ejguan @VitalyFedyunin @NivekT

ejguan commented 1 year ago

> I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When using no reading service, we can simply access the graph using dataloader.datapipe, then I can easily access the state of my datapipe using the code shown below.

When multiprocessing is involved, the partial DataPipe graph is sent to the worker processes, so the main process no longer holds any reference to that partial graph. QueueWrapper is what connects each worker process to the main process.

> As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.

Yes, it is, and we are working on a solution. We probably want to add a new request like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass the request for state to the worker process and have the worker process send back the state of the graph.
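The request/response idea described above can be sketched with the standard library alone. This is a toy model under stated assumptions, not torchdata code: `GetStateRequest`, `GetNextRequest`, `TerminateRequest`, `Counter`, and `worker_loop` are hypothetical names, a thread stands in for the worker process, and `copy.deepcopy` stands in for sending the graph to the worker.

```python
import copy
import queue
import threading


class GetNextRequest:  # hypothetical message: ask the worker for the next item
    pass


class GetStateRequest:  # hypothetical message: ask the worker for its pipe state
    pass


class TerminateRequest:  # hypothetical message: stop the worker loop
    pass


class Counter:
    """Stand-in for a stateful datapipe such as StatefulIterator."""

    def __init__(self, data):
        self.data = list(data)
        self.custom_index = 0

    def __iter__(self):
        self.custom_index = 0
        for item in self.data:
            self.custom_index += 1
            yield item


def worker_loop(pipe, requests, responses):
    """Serve data and state requests against the worker's own copy of the pipe."""
    iterator = iter(pipe)
    while True:
        request = requests.get()
        if isinstance(request, TerminateRequest):
            break
        if isinstance(request, GetNextRequest):
            responses.put(next(iterator, None))
        elif isinstance(request, GetStateRequest):
            responses.put({"custom_index": pipe.custom_index})


def run_demo():
    main_pipe = Counter([1, 2, 3, 4])
    worker_pipe = copy.deepcopy(main_pipe)  # simulates shipping the graph to a worker
    requests, responses = queue.Queue(), queue.Queue()
    worker = threading.Thread(target=worker_loop, args=(worker_pipe, requests, responses))
    worker.start()
    seen = []
    for _ in range(2):
        requests.put(GetNextRequest())
        item = responses.get()
        requests.put(GetStateRequest())
        state = responses.get()
        seen.append((item, state["custom_index"]))
    requests.put(TerminateRequest())
    worker.join()
    # The main-process copy never advanced; only the worker's copy did.
    return seen, main_pipe.custom_index
```

In this model the main process learns the worker's `custom_index` only through the state request, while its own copy of the pipe stays at the initial value.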

Do you have any specific use cases for accessing datapipe state beyond checkpointing?

jhoareau commented 1 year ago

> we probably want to add a new request like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass the request for state to worker process and let worker process send back the state of graph.

This is what I had envisioned as well. Glad to hear it's being worked on. Would you accept a PR adding this functionality?

Our specific use case is for a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding (that's because we can have training on multiple ranks, and we want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe is currently pre-sharding).

Our datapipe is like so: FileOpener -> LineReader -> Map (tokenization) -> MaxTokenBucketizer -> Shard -> Collate

We want to measure the total size of batches produced by MaxTokenBucketizer pre-sharding.

We have a potential workaround by also returning this size with an extra Map before Shard, but we'd prefer not to.
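A minimal sketch of that workaround, using plain Python generators rather than torchdata datapipes. The helper names `attach_running_size` and `shard` are made up for illustration; `shard` only mimics the "see every item, keep every n-th" behaviour of a sharding filter.

```python
def attach_running_size(batches):
    """Pair each batch with the cumulative number of samples seen so far."""
    total = 0
    for batch in batches:
        total += len(batch)
        yield batch, total


def shard(items, num_shards, shard_id):
    """Keep every num_shards-th item, but still iterate over all of them."""
    for index, item in enumerate(items):
        if index % num_shards == shard_id:
            yield item


# Stand-in for the output of a bucketizer: variable-size batches.
batches = [[1, 2], [3], [4, 5, 6], [7]]

# Shard 0 of 2 keeps batches 0 and 2, yet the attached totals count every batch.
rank0 = list(shard(attach_running_size(batches), num_shards=2, shard_id=0))
```

Because the running total is computed before sharding, each batch that reaches rank 0 carries the pre-sharding progress with it, at the cost of threading an extra value through the rest of the pipeline.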

jhoareau commented 1 year ago

FYI, I have started working on a PR that adds that functionality via the snapshot function of the ReadingService, as a PoC. I hope it will fit well with your plans for the feature.

ejguan commented 1 year ago

> Would you accept a PR adding this functionality?

cc: @NivekT as the POC for snapshot/checkpoint. From my perspective, you can open a PR as an RFC, and Kevin will discuss it on the PR since he has a working solution right now. Then we can see whether the two solutions are aligned.

> Our specific use case is for a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding (that's because we can have training on multiple ranks, and we want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe is currently pre-sharding).

Do you mean batch sizes or number of batches?

jhoareau commented 1 year ago

I mean summed batch sizes. I've now created a PR as an RFC.

NivekT commented 1 year ago

@jhoareau I responded here. Let me know if I missed something about your use case. Thanks for opening the issue and PR!

jhoareau commented 1 year ago

Hi @NivekT, thanks for the detailed reply. I'll keep the conversation about state checkpointing in the PR, and will focus on the specific problem I'm trying to solve in this issue.

The documentation is quite vague on how to use sharding_round_robin_dispatch, and I've gotten odd results when using it (4x the amount of data with 4 workers). Would you have any example code showing how to replace sharding_filter with it?

NivekT commented 1 year ago

@jhoareau Can you tell us more about the setup where you are seeing duplicate data (what is the data pipeline)?

For example, here is a multiprocessing example (run with the nightly version):

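from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

def _fn(x):  # not shown in the original snippet; assumed here to be an identity map
    return x

dp1 = IterableWrapper(range(10)).sharding_filter().map(_fn)
dp2 = IterableWrapper(range(10)).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_fn)

for dp in [dp1, dp2]:
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    print(list(dl))  # [0, 1, ..., 9] in both cases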

If you are using DistributedReadingService, then you will want to place a .sharding_filter() prior to .sharding_round_robin_dispatch() in order to divide up the work among nodes first.
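The two-level split described above can be modelled in plain Python. This is a simplified sketch of the partitioning only, not torchdata's implementation; `node_shard`, `round_robin`, and `partition` are hypothetical helpers.

```python
def node_shard(items, num_nodes, rank):
    """First level: each node keeps every num_nodes-th item (like a sharding filter)."""
    return [item for index, item in enumerate(items) if index % num_nodes == rank]


def round_robin(items, num_workers):
    """Second level: deal one node's items out to its workers in turn."""
    return [items[worker::num_workers] for worker in range(num_workers)]


def partition(data, num_nodes, num_workers):
    """Map each node rank to the per-worker item lists it ends up with."""
    return {
        rank: round_robin(node_shard(data, num_nodes, rank), num_workers)
        for rank in range(num_nodes)
    }


# 12 items, 2 nodes, 2 workers per node.
assignment = partition(list(range(12)), num_nodes=2, num_workers=2)
```

In this model every item lands on exactly one worker of exactly one node. If the node-level step were skipped, each node would deal out the full dataset to its own workers.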

Let us know if this is unclear.

jhoareau commented 1 year ago

Hi @NivekT, it works with the sharding filter before the sharding round robin; indeed, we're running multiprocessing + distributed. Thanks for the pointer. However, I needed to monkey-patch round_robin_demux to set the buffer size to -1 (unlimited) for our use case (we collect 50k samples before building batches, so the buffer size of 1000 does not work for us).
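The thread does not spell out why the default buffer was insufficient. The toy model below only illustrates the general trade-off of a bounded round-robin demux buffer when one consumer reads far ahead of the others. `ToyRoundRobinDemux` and `read_ahead` are hypothetical and do not reproduce torchdata's internals; the only convention borrowed from the comment above is that -1 means unlimited.

```python
from collections import deque


class ToyRoundRobinDemux:
    """Deal items from one source to several children, buffering unread ones."""

    def __init__(self, source, num_children, buffer_size):
        self.source = iter(source)
        self.buffers = [deque() for _ in range(num_children)]
        self.buffer_size = buffer_size  # -1 means unlimited in this toy model
        self.next_child = 0

    def get(self, child):
        # Pull from the source until the requested child has an item waiting.
        while not self.buffers[child]:
            item = next(self.source)
            self.buffers[self.next_child].append(item)
            self.next_child = (self.next_child + 1) % len(self.buffers)
            buffered = sum(len(buffer) for buffer in self.buffers)
            if 0 <= self.buffer_size < buffered:
                raise BufferError(f"buffer size {self.buffer_size} is insufficient")
        return self.buffers[child].popleft()


def read_ahead(buffer_size, reads):
    """Read `reads` items from child 0 only, while child 1 never consumes."""
    demux = ToyRoundRobinDemux(range(100), num_children=2, buffer_size=buffer_size)
    return [demux.get(0) for _ in range(reads)]
```

With a small bound, `read_ahead` raises `BufferError` once the idle child's backlog exceeds the limit. With `buffer_size=-1` it succeeds, and memory use grows with the backlog instead.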

I still see value in extracting state from the underlying datapipes with the MPReadingService, so I'll leave my PR up and hope that we can discuss that separately.