pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

What does it mean for a DataPipe to be 'replicable'? #1131

Open lendle opened 1 year ago

lendle commented 1 year ago

📚 The doc issue

The ReadingService docs describe the different sharding options and say that one applies to replicable DataPipes and the other to non-replicable DataPipes, but they don't really explain what that means.

Indirectly related, I'm also confused by the names ShardingRoundRobinDispatcher and ShardingFilter. The docs for ShardingFilter say

each instance of the DataPipe (on different workers) will have every n-th element of the original DataPipe, where n equals to the number of instances.

Is that not essentially the definition of round-robin distribution? How is that different from what the DataPipes downstream of a ShardingRoundRobinDispatcher receive on different workers?

Suggest a potential alternative/fix

Clarify the difference between ShardingRoundRobinDispatcher and ShardingFilter, and explain what 'replicable' means in that context.

If the answer to my questions above is 'yes', possibly consider renaming ShardingRoundRobinDispatcher and ShardingFilter to something more meaningful.

ejguan commented 1 year ago

Replicable means the DataPipe can be copied once for each multiprocessing worker. A DataPipe that is not replicable is instead kept either in a dispatching process, when ShardingRoundRobinDispatcher is used, or in the main process at the end of the pipeline, where it is connected to the replicated DataPipes from each worker process.
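
To make the distinction concrete, here is a rough plain-Python sketch of the two strategies. It does not use torchdata at all; `make_source`, `sharding_filter_style`, and `round_robin_dispatch_style` are hypothetical names made up for illustration. It assumes that replicated pipelines each read the full source and keep every n-th element, while a dispatched source is read once and its elements are handed out in turn.

```python
from itertools import islice


def make_source(n, reads):
    """Hypothetical source; appends to `reads` every time an element is produced."""
    for i in range(n):
        reads.append(i)
        yield i


def sharding_filter_style(n, num_workers):
    """Replicated source: every worker reads its own copy and keeps every n-th element."""
    reads = []
    shards = [
        list(islice(make_source(n, reads), worker_id, None, num_workers))
        for worker_id in range(num_workers)
    ]
    return shards, len(reads)


def round_robin_dispatch_style(n, num_workers):
    """Non-replicated source: read once, elements handed to workers in turn."""
    reads = []
    shards = [[] for _ in range(num_workers)]
    for idx, item in enumerate(make_source(n, reads)):
        shards[idx % num_workers].append(item)
    return shards, len(reads)


if __name__ == "__main__":
    print(sharding_filter_style(5, 2))
    print(round_robin_dispatch_style(5, 2))
```

Under these assumptions both functions hand the same elements to each worker, so the outputs look alike downstream. The difference is in how often the source is read: once per worker in the first case, exactly once in the second. That matters when the source cannot or should not be copied.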

JohnHBrock commented 1 year ago

I agree the docs for ShardingFilter and ShardingRoundRobinDispatcher are confusing.

I created code (see further below for the code and example output) to test my understanding of ShardingRoundRobinDispatcher. Based on the ShardingRoundRobinDispatcher docs, I expected that

dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)

would result in increment running on each worker process, but that

dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)

would result in increment running on a single dispatching process totally separate from the DataLoader worker processes. But as you can see in the output below, increment is still being called across multiple processes.

Code:

import os
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

def increment(x):
    print(f"processs ID {os.getpid()} called increment for {x+1}")
    return x + 1

def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp

def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp

if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    for data in DataLoader(dp1, num_workers=2):
        print(int(data))
    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    for data in DataLoader(dp2, num_workers=2):
        print(int(data))

Output with torchdata 0.6.1 on macOS Ventura 13.3.1 (Intel):

$ python shard_round_robin_test.py
parent process ID: 13495
sharding_round_robin_dispatch BEFORE map(increment):
processs ID 13502 called increment for 1
processs ID 13503 called increment for 1
processs ID 13503 called increment for 2
processs ID 13502 called increment for 2
1
processs ID 13502 called increment for 3
1
processs ID 13503 called increment for 3
2
processs ID 13502 called increment for 4
2
processs ID 13503 called increment for 4
3
processs ID 13502 called increment for 5
3
processs ID 13503 called increment for 5
4
4
5
5

sharding_round_robin_dispatch AFTER map(increment):
processs ID 13510 called increment for 1
processs ID 13511 called increment for 1
processs ID 13510 called increment for 2
processs ID 13511 called increment for 2
1
processs ID 13510 called increment for 3
1
processs ID 13511 called increment for 3
2
processs ID 13510 called increment for 4
2
processs ID 13511 called increment for 4
3
processs ID 13510 called increment for 5
3
processs ID 13511 called increment for 5
4
4
5
5

What's the expected behavior here?

lendle commented 1 year ago

@JohnHBrock I think you need to be using torchdata's DataLoader2, not DataLoader.

JohnHBrock commented 1 year ago

@lendle You're right, I just tested and it works with DataLoader2. Here's the DataLoader2 version of the above code for comparison:

import os
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

def increment(x):
    print(f"processs ID {os.getpid()} called increment for {x+1}")
    return x + 1

def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp

def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp

if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp1, reading_service=mp_reading_service):
        print(int(data))
    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp2, reading_service=mp_reading_service):
        print(int(data))

and here's the output I see:

parent process ID: 88637
sharding_round_robin_dispatch BEFORE map(increment):
processs ID 88646 called increment for 2
processs ID 88645 called increment for 1
processs ID 88646 called increment for 4
1
processs ID 88645 called increment for 3
2
processs ID 88645 called increment for 5
3
4
5

sharding_round_robin_dispatch AFTER map(increment):
processs ID 88650 called increment for 1
processs ID 88650 called increment for 2
processs ID 88650 called increment for 3
processs ID 88650 called increment for 4
1
2
processs ID 88650 called increment for 5
3
4
5

I had initially posted that this didn't work with DataLoader2 either, but I realized there was a bug in my code.
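
For anyone reading the DataLoader2 output above, here is a small plain-Python model of which process ends up calling increment. It is only a sketch of my reading of that output, not of torchdata internals: `simulate` and the "dispatcher" / "worker-N" labels are made up, and no real processes are started.

```python
def simulate(n, num_workers, dispatch_before_map):
    """Return (process_label, value) pairs for each simulated call to increment.

    dispatch_before_map=True models dp.sharding_round_robin_dispatch(...).map(increment);
    dispatch_before_map=False models dp.map(increment).sharding_round_robin_dispatch(...).
    """
    calls = []
    for idx in range(n):
        if dispatch_before_map:
            # map sits downstream of the dispatcher, so it runs in whichever
            # worker receives the element
            label = f"worker-{idx % num_workers}"
        else:
            # map sits upstream of the dispatcher, so it runs in the single
            # dispatching process before elements are handed out
            label = "dispatcher"
        calls.append((label, idx + 1))
    return calls


if __name__ == "__main__":
    for label, value in simulate(5, 2, dispatch_before_map=True):
        print(f"{label} called increment for {value}")
    for label, value in simulate(5, 2, dispatch_before_map=False):
        print(f"{label} called increment for {value}")
```

This matches the pattern in the output: with the dispatch before the map, one worker handles 1, 3, 5 and the other handles 2, 4; with the dispatch after the map, a single process handles all five values.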