pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

DataLoader2 Memory Behavior is very strange on Epoch Resets #1185

Open andrew-bydlon opened 1 year ago

andrew-bydlon commented 1 year ago

🐛 Describe the bug

Memory increase at the start of each epoch after the first

I have been trying to use DataLoader2 (DL2) with multiprocessing, and with distributed training in some cases. Its behavior differs noticeably from the original data loader implementation, which I'll call DataLoader1 (DL1) below. After an epoch (one full iteration) completes, the dataloader appears to hold on to all of its data states instead of resetting them. As a result, memory usage increases from the training epoch to the validation epoch.

More problematic still, when the next epoch starts, the previous epoch's states seem to be held, causing memory usage to spike. I suspect this is behind some of the many recent Memory Error reports, and it certainly caused Memory Errors in my case when training with DDP. DataLoader1 has none of these issues.

I tested with a relatively complicated datapipe that uses multiplexing and several intermediate 1:Multi yielding mechanisms, and produces pairs (audio: tensor, metadata: dict).

[Figure: DataSamplerPipes]

I saw a recent post claiming that dictionaries were the issue. From what I have seen, though, the reading service is more to blame than the dictionaries.

DataLoader2 compared with torch.utils.data.DataLoader (DataLoader1)

Here is the code that I used to produce the results below.

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import Collator
import torch

# Benchmarker (a timing helper), datapipe and task_name are defined elsewhere and not shown.

def BenchmarkLoading(Pipe, N=10000, NumPrints=10, phrase='Decoding Tars, DataPipe'):
    bm = Benchmarker(f'{phrase}, {N} samples')
    for i, x in enumerate(Pipe):
        # Report timing every N batches (skipping i == 0).
        if i % N == 0 and i:
            print('Iteration', i, 'of', N * NumPrints)
            bm()
        if i == N * NumPrints:
            break
    bm.compute()

batch_size = 1024  # Fixed for testing.
num_workers = 8
rs = MultiProcessingReadingService(num_workers=num_workers)
dataloader = DataLoader2(Collator(datapipe.batch(batch_size)), reading_service=rs)
# DL1 comparison run:
# dataloader = torch.utils.data.DataLoader(datapipe, num_workers=num_workers, batch_size=batch_size)

for i in range(100):
    print(i)
    dataloader.seed(i)  # DataLoader2 only; drop this line for the DL1 run.
    BenchmarkLoading(dataloader, N=100, NumPrints=100, phrase=f'{task_name}, Data loader (not 2), batches {batch_size}, workers {num_workers}')
    print('\n\n')
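Benchmarker is not shown in the issue. A minimal hypothetical stand-in, assuming that calling the instance records the time since the previous call and that compute() summarizes the recorded laps (consistent with the "... took N seconds." lines in the logs below), could look like this:

```python
import time


class Benchmarker:
    """Hypothetical stand-in for the issue's unshown Benchmarker helper.

    Calling the instance records the wall-clock time since the previous
    call (or since construction); compute() summarizes the recorded laps.
    """

    def __init__(self, label):
        self.label = label
        self.laps = []
        self._last = time.perf_counter()

    def __call__(self):
        now = time.perf_counter()
        lap = now - self._last
        self.laps.append(lap)
        self._last = now
        print(f'{self.label} took {lap} seconds.')
        return lap

    def compute(self):
        # Summarize all laps as (count, total, mean); mean is 0.0 when empty.
        total = sum(self.laps)
        count = len(self.laps)
        mean = total / count if count else 0.0
        print(f'{self.label}: {count} laps, total {total:.2f} s, mean {mean:.2f} s')
        return count, total, mean
```

With this stand-in, the first lap of every epoch is measured from the moment the Benchmarker is created, just before the dataloader's iterator is built, so it includes any worker startup cost. That is the window where the DL1 and DL2 timings below diverge.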

Results:

DL1: Very consistent behavior.

"Epoch 0"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.31092572212219 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 12.599214553833008 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 13.167328834533691 seconds.

"Epoch 1"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.44526529312134 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 11.10575008392334 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 16.893246173858643 seconds.

"Epoch 2"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.24127960205078 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 10.061070203781128 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.461547136306763 seconds.

[Figure: memory usage graph, TorchDataLoader1]

DL2: Consistently poor performance at the start of each epoch after "Epoch 0"

"Epoch 0"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 55.075947523117065 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.734715938568115 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.988599061965942 seconds.

"Epoch 1"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 144.52063298225403 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 8.694239616394043 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 11.868495464324951 seconds.

"Epoch 2"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 144.5120747089386 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.244807004928589 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 12.066998481750488 seconds.

"Epoch 3"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 147.51504135131836 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.539249181747437 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.461547136306763 seconds.

[Figure: memory usage graph, TorchDataLoader2]
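To put a number on the slowdown, the first-window timings (first 100 batches of each epoch) printed in the logs above can be compared directly. This is plain arithmetic on those values, using DL2's "Epoch 0" as the baseline because it matches DL1:

```python
# Seconds for the first 100 batches of each epoch, copied from the logs above.
dl1_first = [54.31092572212219, 54.44526529312134, 54.24127960205078]
dl2_first = [55.075947523117065, 144.52063298225403, 144.5120747089386, 147.51504135131836]

# DL2's first epoch behaves like DL1, so use it as the baseline.
baseline = dl2_first[0]
ratios = [t / baseline for t in dl2_first[1:]]
extra_seconds = [t - baseline for t in dl2_first[1:]]

print('DL1 spread (max/min):', round(max(dl1_first) / min(dl1_first), 3))
print('DL2 slowdown vs. epoch 0:', [round(r, 2) for r in ratios])
print('DL2 extra seconds per epoch:', [round(s, 1) for s in extra_seconds])
```

DL1 stays within about 0.4% across epochs, while every DL2 epoch after the first takes roughly 2.6 to 2.7 times as long for its first 100 batches, about 89 to 92 extra seconds per epoch.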

Attempts to get the resetting behavior of DL1

I studied the internal variable states embedded in DataLoader2 and the reading service.

In the reading service, the per-worker processes, queues, and datapipes stick around after the epoch ends:

 {'num_workers': 8,
 'multiprocessing_context': None,
 'worker_prefetch_cnt': 10,
 'main_prefetch_cnt': 10,
 'worker_init_fn': None,
 'worker_reset_fn': None,
 '_worker_processes': [(<ForkProcess name='ForkProcess-9' pid=27584 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8980c57b20>,
   <multiprocessing.queues.Queue at 0x7f8971607940>),
  (<ForkProcess name='ForkProcess-10' pid=27585 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8980c57df0>,
   <multiprocessing.queues.Queue at 0x7f8971607e80>),
  (<ForkProcess name='ForkProcess-11' pid=27617 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f89716079a0>,
   <multiprocessing.queues.Queue at 0x7f8971680490>),
  (<ForkProcess name='ForkProcess-12' pid=27649 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f89716800d0>,
   <multiprocessing.queues.Queue at 0x7f8971680a60>),
  (<ForkProcess name='ForkProcess-13' pid=27669 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8971680670>,
   <multiprocessing.queues.Queue at 0x7f8971681030>),
  (<ForkProcess name='ForkProcess-14' pid=27682 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8971680c40>,
   <multiprocessing.queues.Queue at 0x7f89716814b0>),
  (<ForkProcess name='ForkProcess-15' pid=27745 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8971681120>,
   <multiprocessing.queues.Queue at 0x7f8971681a80>),
  (<ForkProcess name='ForkProcess-16' pid=27777 parent=23977 started daemon>,
   <multiprocessing.queues.Queue at 0x7f8971681690>,
   <multiprocessing.queues.Queue at 0x7f8971682050>)],
 '_dispatch_process': None,
 '_worker_datapipes': [QueueWrapper,
  QueueWrapper,
  QueueWrapper,
  QueueWrapper,
  QueueWrapper,
  QueueWrapper,
  QueueWrapper,
  QueueWrapper],
 '_worker_consumer_datapipe': _IterateQueueDataPipes,
 '_main_prefetch_datapipe': PrefetcherIterDataPipe,
 '_end_datapipe': PrefetcherIterDataPipe,
 '_mp': True}

I was able to eliminate the slow startup at the beginning of each epoch by resetting all of these attributes (plus a few on DataLoader2 itself) to their initial values. However, this created a whole new dataloader and doubled the memory usage (see attached image).

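# Manually restore DataLoader2 and its reading service to their initial (un-adapted) state.
# These are private attributes and may differ between torchdata versions.
# Note: this only drops references; the old worker processes are not shut down here.
dataloader.reading_service._worker_processes = []
dataloader.reading_service._worker_datapipes = []
dataloader.reading_service._worker_consumer_datapipe = None
dataloader.reading_service._main_prefetch_datapipe = None
dataloader.reading_service._end_datapipe = None
dataloader.datapipe = dataloader._datapipe_before_reading_service_adapt
dataloader._datapipe_iter = None
dataloader.valid_iterator_id = None
dataloader._adapted = False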

[Figure: TorchDataLoader2_manual_resets]

Q: Is there a way to embed the reset behavior into the 'worker_reset_fn' variable of the reading service without causing the memory increase?

Are there other recommended ways to hard-reset the data loader every epoch? Compared to DL1, it is much less efficient to keep that memory allocated and, when resetting, to briefly use two dataloaders' worth of RAM. It also nearly triples the per-epoch startup time of my jobs before throughput returns to normal.

I left my original comment here: https://github.com/pytorch/data/issues/1150

A small note on the datapipe alone, isolating the problem to the reading service

Iterating the datapipe on its own gives very consistent performance after the iterator is reset. This may already be clear from the DL1 results, but since I ran the test, I'm showing it here: [Figure: OnlyDatapipeRunning]

Versions

[Attachment: EnvInfo.txt]

andrew-bydlon commented 1 year ago

@ejguan: Do you have any suggestions for properly resetting DataLoader2 after each epoch, e.g. with worker_reset_fn?

Adenialzz commented 1 year ago

Hello, I have also run into DL2 memory usage skyrocketing, so I have temporarily switched back to DL1. How do I set up datapipe + DL1 for multi-process, multi-GPU training? Do I need to set up distributed sampling in DL1?

andrew-bydlon commented 1 year ago

@Adenialzz: To get what I showed above, the setup is more or less the same as for a torch Dataset; just replace the dataset with a datapipe.

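from torch.utils.data import DataLoader, DistributedSampler
# Needs len(datapipe), i.e. a map-style datapipe; DataLoader rejects a sampler for iterable-style datasets.
sampler = DistributedSampler(datapipe) if distributed else None
return DataLoader(datapipe, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, batch_size=batch_size)
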
Adenialzz commented 1 year ago

DistributedSampler requires my dataset (datapipe) to have a __len__ method, but the length of my datapipe cannot be computed because it is an iterable datapipe. Have you run into this problem?

andrew-bydlon commented 1 year ago

I'll give it a try today.

Adenialzz commented 1 year ago

Thanks, please let me know when you make progress.

andrew-bydlon commented 1 year ago

Sorry for the delay, @Adenialzz. You are correct: it doesn't work with DDP for an iterable datapipe that has no length. I reverted to DL2 despite its notably slower performance, since the slowdown only really occurs at the start of each epoch.

Adenialzz commented 1 year ago

Calling torch.utils.data.graph_settings.apply_sharding(datapipe, world_size, rank) seems to solve the problem in my case.
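For context on why this sidesteps the missing __len__: apply_sharding configures the sharding_filter() step(s) in the datapipe graph, so it assumes the pipeline contains one. Each rank then keeps only the items whose index i satisfies i % world_size == rank, which requires no length. The stdlib-only sketch below mimics that round-robin split; shard is a hypothetical helper for illustration, not a torch API.

```python
from itertools import islice


def shard(iterable, world_size, rank):
    """Hypothetical helper: yield every world_size-th item starting at index rank.

    Mimics the round-robin split a sharding_filter applies once it is
    configured with (world_size, rank); the stream needs no __len__.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f'rank must be in [0, {world_size}), got {rank}')
    return islice(iterable, rank, None, world_size)


# Each of 4 ranks sees a disjoint part of the same stream.
stream = range(10)
for rank in range(4):
    print(rank, list(shard(stream, world_size=4, rank=rank)))
```

Because the split depends only on each item's position, every rank must iterate the same stream in the same order (for example, the same shuffle seed) for the shards to stay disjoint.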

keunwoochoi commented 4 months ago

@Adenialzz Hi, could I ask you for a clarification? How exactly did you use it, and which problem did it fix? I'd appreciate it very much.