lendle opened this issue 1 year ago
`replicable` means the DataPipe can be copied multiple times for multiprocessing workers. If it's not, it will either be kept in a dispatching process when `ShardingRoundRobinDispatcher` is used, or kept in the main process at the end, connected to all replicated DataPipes from each worker process.
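To make the distinction concrete, here is a minimal sketch (mine, not from the torchdata docs) of the two graph shapes; it only constructs the pipelines and does not run them:

```python
# Sketch contrasting a replicable pipeline with a non-replicable one.
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES


def increment(x):
    return x + 1


# Replicable: the whole pipeline is copied into every worker process. Each
# copy iterates the full source, and sharding_filter keeps every n-th element.
replicated = IterableWrapper(range(10)).sharding_filter().map(increment)

# Non-replicable: everything up to and including the dispatcher stays in a
# single dispatching process, which hands elements to the workers round-robin;
# only the DataPipes after the dispatcher are replicated across workers.
dispatched = (
    IterableWrapper(range(10))
    .sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    .map(increment)
)
```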
I agree the docs for `ShardingFilter` and `ShardingRoundRobinDispatcher` are confusing.

I created code (see further below for the code and example output) to test my understanding of `ShardingRoundRobinDispatcher`. Based on the `ShardingRoundRobinDispatcher` docs, I expected that `dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)` would result in `increment` running on each worker process, but that `dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)` would result in `increment` running on a single dispatching process, totally separate from the `DataLoader` worker processes. But as you can see in the output below, `increment` is still being called across multiple processes.
Code:
```python
import os

from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES


def increment(x):
    print(f"process ID {os.getpid()} called increment for {x + 1}")
    return x + 1


def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp


def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp


if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    for data in DataLoader(dp1, num_workers=2):
        print(int(data))

    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    for data in DataLoader(dp2, num_workers=2):
        print(int(data))
```
Output with torchdata 0.6.1 on macOS Ventura 13.3.1 (Intel):
```
$ python shard_round_robin_test.py
parent process ID: 13495
sharding_round_robin_dispatch BEFORE map(increment):
process ID 13502 called increment for 1
process ID 13503 called increment for 1
process ID 13503 called increment for 2
process ID 13502 called increment for 2
1
process ID 13502 called increment for 3
1
process ID 13503 called increment for 3
2
process ID 13502 called increment for 4
2
process ID 13503 called increment for 4
3
process ID 13502 called increment for 5
3
process ID 13503 called increment for 5
4
4
5
5
sharding_round_robin_dispatch AFTER map(increment):
process ID 13510 called increment for 1
process ID 13511 called increment for 1
process ID 13510 called increment for 2
process ID 13511 called increment for 2
1
process ID 13510 called increment for 3
1
process ID 13511 called increment for 3
2
process ID 13510 called increment for 4
2
process ID 13511 called increment for 4
3
process ID 13510 called increment for 5
3
process ID 13511 called increment for 5
4
4
5
5
```
What's the expected behavior here?
@JohnHBrock I think you need to be using torchdata's `DataLoader2`, not `DataLoader`.
@lendle You're right, I just tested and it works with `DataLoader2`. Here's the `DataLoader2` version of the above code for comparison:
```python
import os

from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES


def increment(x):
    print(f"process ID {os.getpid()} called increment for {x + 1}")
    return x + 1


def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp


def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp


if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp1, reading_service=mp_reading_service):
        print(int(data))

    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp2, reading_service=mp_reading_service):
        print(int(data))
```
and here's the output I see:
```
parent process ID: 88637
sharding_round_robin_dispatch BEFORE map(increment):
process ID 88646 called increment for 2
process ID 88645 called increment for 1
process ID 88646 called increment for 4
1
process ID 88645 called increment for 3
2
process ID 88645 called increment for 5
3
4
5
sharding_round_robin_dispatch AFTER map(increment):
process ID 88650 called increment for 1
process ID 88650 called increment for 2
process ID 88650 called increment for 3
process ID 88650 called increment for 4
1
2
process ID 88650 called increment for 5
3
4
5
```
I had initially posted that this didn't work with `DataLoader2` either, but I realized there was a bug in my code.
📚 The doc issue
The ReadingService docs mention the different sharding options and that one applies to replicable and the other to non-replicable DataPipes, but it's not really explained what that means.
Indirectly related, I'm also confused by the names `ShardingRoundRobinDispatcher` and `ShardingFilter`. The docs for `ShardingFilter` say that each instance of the DataPipe (on different workers) will have every n-th element of the original DataPipe, where n equals the number of instances. Is that not essentially the definition of round-robin distribution? How is that different from what the DataPipes downstream of a `ShardingRoundRobinDispatcher` on different workers receive?

Suggest a potential alternative/fix
Clarify more the difference between `ShardingRoundRobinDispatcher` and `ShardingFilter`, and explain what 'replicable' means in that context. Possibly consider renaming `ShardingRoundRobinDispatcher` and `ShardingFilter` to something more meaningful, if the answers to my questions above are 'yes'.
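For concreteness, here is a minimal single-process sketch of the "every n-th element" behavior the `ShardingFilter` docs describe. I'm assuming `apply_sharding` from `torch.utils.data.graph_settings` here; in real use, the loader calls this once per worker with that worker's instance ID:

```python
# Single-process sketch of ShardingFilter's "every n-th element" behavior.
from torch.utils.data.graph_settings import apply_sharding
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10)).sharding_filter()
apply_sharding(dp, 2, 0)  # pretend to be worker 0 of 2
print(list(dp))  # [0, 2, 4, 6, 8]; worker 1 of 2 would get [1, 3, 5, 7, 9]
```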