jhoareau opened 1 year ago
I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When no reading service is used, we can simply access the graph via `dataloader.datapipe`, and then I can easily access the state of my datapipe using the code shown below.
When MP gets involved, the partial DataPipe graph is sent to the worker process, so there won't be any reference to that partial graph from the main process. `QueueWrapper` is the piece that connects the worker process back to the main process.
As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.
Yes, it is, and we are working on a solution for it. We will probably want to add a new request type, like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21, to pass the state request to the worker process and let the worker process send back the state of its graph.
Do you have any specific use cases for accessing datapipe state, on top of checkpointing?
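For context, a minimal sketch of that request/response idea, using in-process queues as a stand-in for the worker IPC. The `GetStateRequest`/`GetStateResponse` names and the `TinyPipe` datapipe are illustrative, not torchdata's actual API:

```python
import queue

# Hypothetical message types modeled loosely on the style of
# torchdata's dataloader2.communication.messages; names are
# illustrative, not the real API.
class GetStateRequest:
    pass

class GetStateResponse:
    def __init__(self, state):
        self.state = state

def worker_loop_step(request_queue, response_queue, datapipe):
    """One dispatch step of a worker: answer a state request by
    serializing the worker-side datapipe graph state."""
    msg = request_queue.get()
    if isinstance(msg, GetStateRequest):
        # __getstate__ captures the picklable state of the graph
        response_queue.put(GetStateResponse(datapipe.__getstate__()))

class TinyPipe:
    """Stand-in for a stateful DataPipe living in the worker."""
    def __init__(self):
        self.position = 7
    def __getstate__(self):
        return {"position": self.position}

req_q, resp_q = queue.Queue(), queue.Queue()
req_q.put(GetStateRequest())
worker_loop_step(req_q, resp_q, TinyPipe())
print(resp_q.get().state)  # {'position': 7}
```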
we probably want to add a new request like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass the request for state to worker process and let worker process send back the state of graph.
This is what I had envisioned as well. Glad to hear it's being worked on. Would you accept a PR adding this functionality?
Our specific use case is for a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding (that's because we can have training on multiple ranks, and we want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe is currently pre-sharding).
Our datapipe looks like this: FileOpener -> LineReader -> Map (tokenization) -> MaxTokenBucketizer -> Shard -> Collate. We want to measure the total size of the batches produced by MaxTokenBucketizer pre-sharding.
We have a potential workaround: returning this size from an extra Map before Shard, but we'd prefer not to.
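For reference, that workaround amounts to a pass-through map stage that tallies summed batch sizes on the side. Plain generators stand in for DataPipes here, and the `counter` side channel is illustrative:

```python
# Sketch of the workaround: a pass-through stage, placed before the
# sharding step, that records the summed batch sizes as batches flow by.
counter = {"samples_seen": 0}

def count_batch_sizes(batches, counter):
    for batch in batches:
        counter["samples_seen"] += len(batch)  # tally pre-sharding progress
        yield batch  # pass the batch through unchanged

batches = [[1, 2, 3], [4, 5], [6]]
consumed = list(count_batch_sizes(batches, counter))
print(counter["samples_seen"])  # 6
```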
FYI, I have started working on a PR that adds that functionality via the `snapshot` function of the ReadingService, as a PoC. I hope it will fit well with your plans for the feature.
Would you accept a PR adding this functionality?
cc: @NivekT as the POC for snapshot/checkpoint. From my perspective, you can open a PR as an RFC, and Kevin will discuss it on the PR, since he has a working solution right now. Then we can see whether those solutions are aligned.
Our specific use case is for a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding (that's because we can have training on multiple ranks, and we want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe is currently pre-sharding).
Do you mean batch sizes or number of batches?
I mean summed batch sizes. I've now created a PR as an RFC.
@jhoareau I responded here. Let me know if I missed something about your use case. Thanks for opening the issue and PR!
Hi @NivekT, thanks for the detailed reply. I'll keep the conversation about state checkpointing in the PR, and will focus on the specific problem I'm trying to solve in this issue.
The documentation is quite vague on how to use `sharding_round_robin_dispatch`, and I've gotten odd results using it (4x the amount of data with 4 workers). Would you have any example code on how to replace `sharding_filter` with it?
@jhoareau Can you tell us more about the setup where you are seeing duplicate data (what is the data pipeline)?
For example, here is a multiprocessing example (run with the nightly version):

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

def _fn(x):  # identity map, standing in for any per-sample transform
    return x

dp1 = IterableWrapper(range(10)).sharding_filter().map(_fn)
dp2 = IterableWrapper(range(10)).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_fn)

for dp in [dp1, dp2]:
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    print(list(dl))  # [0, 1, ..., 9] in both cases
If you are using `DistributedReadingService`, then you will want to place a `.sharding_filter()` prior to `.sharding_round_robin_dispatch()` in order to divide up the work among nodes first. Let us know if this is unclear.
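The duplication effect can be simulated without torchdata: with no sharding step, each worker replays the whole source datapipe, while a `sharding_filter`-style split gives each worker a disjoint slice. The function names below are illustrative:

```python
# Minimal simulation of why a missing sharding step yields
# num_workers duplicate copies of the data.
def without_sharding(source, num_workers):
    # each worker iterates the entire source independently
    return [item for _ in range(num_workers) for item in source]

def with_sharding_filter(source, num_workers):
    # worker i keeps every num_workers-th element, like .sharding_filter()
    out = []
    for worker_id in range(num_workers):
        out.extend(source[worker_id::num_workers])
    return out

source = list(range(10))
print(len(without_sharding(source, 4)))         # 40 -> "4x the data"
print(sorted(with_sharding_filter(source, 4)))  # [0, 1, ..., 9], each item once
```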
Hi @NivekT, it works with the sharding filter before the sharding round robin; indeed, we're running multiprocessing + distributed. Thanks for the pointer. However, I needed to monkey-patch `round_robin_demux` to set the buffer size to -1 (unlimited) for our use case (we collect 50k samples before building batches, so the default buffer size of 1000 does not work for us).
I still see value in extracting state from the underlying datapipes with the MPReadingService, so I'll leave my PR up, and I hope we can also discuss that separately.
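For what it's worth, such a monkey-patch amounts to re-binding the function with a pinned keyword argument. The `round_robin_demux` below is a stand-in, since the real torchdata signature may differ:

```python
import functools

# Stand-in for the real function; illustrative only.
def round_robin_demux(datapipe, num_instances, buffer_size=1000):
    return f"demux({num_instances} ways, buffer={buffer_size})"

# Pin buffer_size=-1 (unlimited) while leaving other arguments alone.
_original = round_robin_demux
round_robin_demux = functools.partial(_original, buffer_size=-1)

print(round_robin_demux("dp", 4))  # demux(4 ways, buffer=-1)
```

In a real setup, the patched callable would be assigned back onto the module that defines `round_robin_demux` before the pipeline is built.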
Hi TorchData team,
I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When no reading service is used, we can simply access the graph via `dataloader.datapipe`, and then I can easily access the state of my datapipe using the code shown below. However, in the multi-processing case, the datapipe graph is replaced with `QueueWrapper` instances, and I cannot find any way to communicate with the workers to get access to the state of the datapipe (I get an error that my StatefulIterator cannot be found on the datapipe). If I access `dl2._datapipe_before_reading_service_adapt`, I only get the initial state, which makes sense since there is no state sync between the main and worker processes. As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.
Potentially, could we add a `getstate` communication primitive in `communication.messages` in order to capture the state (via `getstate`) of a datapipe in a worker process? We're also open to using `sharding_round_robin_dispatch` in order to keep more information in the main process, but I'm a bit confused about how to use it; do you have some sample code for the following case? Running against today's master (commit a3b34a00e7d2b6694ea0d5e21fcc084080a3abae):
cc: @ejguan @VitalyFedyunin @NivekT