pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Make accessing WorkerInfo from within a DataPipe more convenient #1084

Open sehoffmann opened 1 year ago

sehoffmann commented 1 year ago

🚀 The feature

import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import MultiProcessingReadingService, DataLoader2

my_worker_info = None  # module-level global; set per worker process by worker_init

def abc(x):
    return x * my_worker_info.worker_id

def worker_init(dp, worker_info):
    # Called by the reading service in each worker process; stash the WorkerInfo globally.
    global my_worker_info
    my_worker_info = worker_info
    return dp

pipe = dp.iter.IterableWrapper(range(10))
pipe = pipe.map(abc)
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = MultiProcessingReadingService(num_workers=2, worker_init_fn=worker_init)
dl = DataLoader2(pipe, reading_service=rs)

for x in dl:
    print(x)

Output:

0
1
0
3
0
5
0
7
0
9

To my knowledge, this is the only way to access the WorkerInfo from within a DataPipe when using DataLoader2 (in the output above, worker 0 multiplies its shard by worker_id 0 and worker 1 by worker_id 1, hence the alternating zeros and odd numbers). Relying on global state is obviously awkward and becomes a problem in larger codebases that aren't toy examples. It would be good if there were a more convenient way (and one that is also uniform with respect to DataLoader), akin to get_worker_info.

Traversing the graph and calling set_worker_info if available would be a good option for this IMO.
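
For comparison, this is the pattern I'd like parity with from the classic DataLoader; a minimal sketch using the existing torch.utils.data.get_worker_info (this is not DataLoader2 code):

# Classic DataLoader equivalent: the map function can query its worker directly,
# no module-level global required. Not applicable to DataLoader2 today.
from torch.utils.data import get_worker_info

def abc(x):
    info = get_worker_info()  # WorkerInfo inside a worker process, None in the main process
    worker_id = info.id if info is not None else 0
    return x * worker_id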

Motivation, pitch

I want to easily access the current WorkerInfo from my datapipe.

Alternatives

No response

Additional context

No response

ejguan commented 1 year ago

So, I guess you want a DataPipe that behaves differently based on WorkerInfo. I think adding get_worker_info is a good feature request.

However, adding set_worker_info to each DataPipe might be too much, as not all DataPipes would need it and it would require a registry on DataPipe.

sehoffmann commented 1 year ago

Hey @ejguan

I think I would be fine with either. get_worker_info(), however, would rely on global state(?) and could produce issues when multiple independent datapipes are iterated in parallel (I know, a bit hypothetical, just saying).

However, adding set_worker_info to each DataPipe might be too much, as not all DataPipes would need it and it would require a registry on DataPipe.

No, I don't think so. It would be in line with how sharding and shuffling work at the moment. I.e. one just needs to do something similar to:


# Sketch; the import paths for these graph utilities may vary between torch versions.
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes

def apply_worker_info(datapipe, worker_info):
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, 'set_worker_info'):
            pipe.set_worker_info(worker_info)
    return datapipe
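
For illustration, a custom DataPipe opting into such a hook might look roughly like this (set_worker_info is the proposed method name, it does not exist in torchdata today, and WorkerInfo is assumed to expose worker_id as in the example above):

# Hypothetical sketch: a custom IterDataPipe that opts into the proposed set_worker_info hook.
from torch.utils.data import IterDataPipe

class WorkerAwareMapper(IterDataPipe):
    def __init__(self, source_dp):
        self.source_dp = source_dp
        self.worker_info = None  # filled in by the ReadingService via apply_worker_info

    def set_worker_info(self, worker_info):
        self.worker_info = worker_info

    def __iter__(self):
        worker_id = self.worker_info.worker_id if self.worker_info is not None else 0
        for x in self.source_dp:
            yield x * worker_id

This keeps the per-worker state on the pipe itself instead of in a module-level global.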
ejguan commented 1 year ago

get_worker_info(), however, would rely on global state(?) and could produce issues when multiple independent datapipes are iterated in parallel (I know, a bit hypothetical, just saying).

Since they are running in separate subprocesses, it should be fine.

No, I don't think so. It would be in line with how sharding and shuffling work at the moment. I.e. one just needs to do something similar to:

I see what you mean. So, you want some custom DataPipe to accept it. I have a slight concern about how to provide this information to a DataPipe in the dispatching process or in the main process.

sehoffmann commented 1 year ago

Isn't the worker information only relevant when using the MPRS, the DistributedReadingService, or both? I don't see how it is technically any different from, e.g., sharding information.

Also, one thing to keep in mind with all these interfaces (including sharding and shuffling) is that people also need to be able to set them easily in their own ReadingServices. For instance, I'm rolling my own HorovodReadingService.

On a side note: is there interest in a PR for the HorovodReadingService?

ejguan commented 1 year ago

Isn't the worker information only relevant when using the MPRS, the DistributedReadingService, or both? I don't see how it is technically any different from, e.g., sharding information.

The dispatching process is tied to MPRS as well. And, as we discussed, there might be a partial DataPipe graph remaining in the main process when MPRS gets involved. So, in those cases, we have to ask users/developers to handle the case where WorkerInfo is not provided.

On a side note: is there interest in a PR for the HorovodReadingService?

It would be good if you could share more context, e.g. in an RFC issue.

sehoffmann commented 1 year ago

@ejguan A specific use case that I would like to handle with this:

pipe = pipe.repeat(N_workers).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)

Here, I would like to introduce a custom operation that doesn't know the number of workers a priori. A set_worker_info (or global get_worker_info) feature should also take a sharding priority as an argument so that we can specify what kind of worker info we are interested in, i.e. distributed (mpi_rank, mpi_size) vs. multiprocessing (process_rank, process_count).
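
A minimal sketch of what such a priority-aware hook could look like on a custom pipe; set_worker_info, its priority argument, and RepeatPerWorker are all hypothetical and do not exist in torchdata today:

# Hypothetical sketch only: set_worker_info and its priority argument are proposed, not existing, API.
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

class RepeatPerWorker(IterDataPipe):
    """Repeats each element once per multiprocessing worker, learning the count at init time."""

    def __init__(self, source_dp):
        self.source_dp = source_dp
        self.num_workers = 1  # fallback when no reading service supplies the info

    def set_worker_info(self, worker_info, priority=SHARDING_PRIORITIES.MULTIPROCESSING):
        # Only react to the multiprocessing view; a distributed view (mpi_rank, mpi_size)
        # could be handled analogously via SHARDING_PRIORITIES.DISTRIBUTED.
        if priority == SHARDING_PRIORITIES.MULTIPROCESSING:
            self.num_workers = worker_info.num_workers

    def __iter__(self):
        for x in self.source_dp:
            for _ in range(self.num_workers):
                yield x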

ejguan commented 1 year ago

Technically speaking, you can add an Adapter object to DataLoader2 to achieve in-place graph modification, because you should be able to know the number of workers and the distributed ranks at initialization time of DataLoader2.

If you want to access the information for a specific MP worker, you probably need a get_worker_info function.
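
A rough sketch of that workaround, assuming the number of workers is known when DataLoader2 is constructed; WorkerCountAdapter and set_num_workers are made-up names, and the graph-utility import paths may differ between torch versions:

# Sketch of the Adapter-based workaround; WorkerCountAdapter / set_num_workers are invented names.
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.adapter import Adapter

class WorkerCountAdapter(Adapter):
    def __init__(self, num_workers):
        self.num_workers = num_workers

    def __call__(self, datapipe):
        # Runs once when DataLoader2 is constructed, before the graph is handed to workers.
        for pipe in get_all_graph_pipes(traverse_dps(datapipe)):
            if hasattr(pipe, 'set_num_workers'):
                pipe.set_num_workers(self.num_workers)
        return datapipe

# Usage (hypothetical): dl = DataLoader2(pipe, datapipe_adapter_fn=WorkerCountAdapter(2), reading_service=rs)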

sehoffmann commented 1 year ago

Yes, for now I can work around this. I just wrote this as an example of a real use case and its specific requirements, and thought it might be helpful for you when deciding on a design.