sesquipedalianist opened this issue 1 year ago
Have you tried the same pipeline with DataLoader2 and MultiProcessingReadingService?
I just tried it (I had previously tried DataLoader2 but perhaps with torch 1.13.1) and the spikes still occur. This makes sense to me because it seems the datapipe graph is still traversed in the same way.
I posted about another, possibly related MemoryError here: https://discuss.pytorch.org/t/torchdata-w-ddp-start-of-epoch-2-get-memoryerror/179523
My MemoryError also occurs at the start of an epoch while using DDP and distributed multiprocessing. It seems to depend on the size of the shuffles I put into the datapipe (one for files, one for fixed-length decoding, one for augmentations), since most recently I got through 9 epochs before hitting the OOM error.
It's really weird. I use less than 150 GB of RAM during training, yet all 500 GB of RAM are exhausted at the beginning of epoch 2. I have considered shutting down and restarting the pipeline as a workaround.
@sesquipedalianist Thanks for reporting. So it happens after the first epoch, and because of in_memory_cache, the memory used while traversing the DataPipes becomes significant. This may require an in-depth investigation into how to properly exclude inner objects such as buffers during traversal, since only the DataPipes themselves are needed.
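A minimal pure-Python sketch of why a large cache makes this traversal expensive (an illustration, not torch code): CachedPipe is a hypothetical stand-in for a DataPipe whose cache holds the data. Pickling it into an io.BytesIO, which the bug report below says the traversal in graph.py does, writes a serialized copy of the entire cache into that buffer, and the copy stays alive until the buffer is freed.

```python
import io
import pickle


class CachedPipe:
    """Hypothetical stand-in for a DataPipe whose in-memory cache holds the data."""

    def __init__(self, n_items: int, item_size: int) -> None:
        # Distinct bytes objects, so pickle's memo cannot deduplicate them.
        self.cache = [bytes(item_size) for _ in range(n_items)]


def serialized_size(obj) -> int:
    """Pickle obj into an in-memory buffer, as a BytesIO-based traversal would."""
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
    # Every pickled byte now lives in f until f is garbage-collected.
    return f.tell()


payload = 4 * 1024 * 1024  # 4 MiB of cached data in total
pipe = CachedPipe(n_items=4096, item_size=1024)
size = serialized_size(pipe)
print(f"cached data: {payload} bytes, pickled buffer: {size} bytes")
```

On this toy input the buffer ends up slightly larger than the 4 MiB of cached data, since pickle adds a few bytes of opcodes per item.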
One approach might be to add a wrapper around those objects to prevent them from going through pickle during traverse.
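The wrapper idea could look roughly like this sketch. _SkipPickle and CachedPipe are hypothetical names, not part of torch or torchdata: the wrapper keeps its payload in memory but pickles as an empty placeholder via __reduce__, so serializing the owning pipe no longer copies the buffer. The trade-off is that anything that genuinely needs the buffer after unpickling (for example, a worker process receiving the pipe) would get None instead.

```python
import pickle


class _SkipPickle:
    """Hypothetical wrapper: keeps `value` in memory but drops it when pickled."""

    def __init__(self, value=None):
        self.value = value

    def __reduce__(self):
        # Rebuild as an empty wrapper on unpickling; the payload is never serialized.
        return (_SkipPickle, ())


class CachedPipe:
    """Hypothetical pipe whose large buffer is wrapped so pickling skips it."""

    def __init__(self, n_items: int, item_size: int) -> None:
        self.name = "cached"
        self._buffer = _SkipPickle([bytes(item_size) for _ in range(n_items)])


pipe = CachedPipe(n_items=4096, item_size=1024)
data = pickle.dumps(pipe)      # only a few hundred bytes, despite ~4 MiB buffered
restored = pickle.loads(data)  # restored._buffer.value is None
print(len(data), restored.name, restored._buffer.value)
```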
@andrew-bydlon It's kind of weird to me, because the shuffler should not contain any buffer during reset at epoch 2, and its buffer is the only place that can hold any data. So this peak memory consumption is not expected. It would be great if you could share a minimal reproducible script so we can reproduce and debug it.
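To make the buffer argument concrete, here is a simplified buffered shuffle in plain Python. It is an illustration under stated assumptions, not torchdata's Shuffler: BufferedShuffle is a hypothetical class whose buffer holds at most buffer_size samples while iterating and is emptied on reset, so at the start of an epoch it holds nothing large.

```python
import random
from typing import Any, Iterable, Iterator, List


class BufferedShuffle:
    """Hypothetical, simplified buffered shuffle (not torchdata's Shuffler)."""

    def __init__(self, source: Iterable[Any], buffer_size: int, seed: int = 0) -> None:
        if buffer_size < 1:
            raise ValueError("buffer_size must be at least 1")
        self.source = source
        self.buffer_size = buffer_size
        self.rng = random.Random(seed)
        self._buffer: List[Any] = []

    def reset(self) -> None:
        # Start of an epoch: drop any leftover samples so the buffer is empty.
        self._buffer = []

    def __iter__(self) -> Iterator[Any]:
        self.reset()
        for sample in self.source:
            if len(self._buffer) == self.buffer_size:
                # Buffer full: emit a random buffered sample, store the new one in its slot.
                idx = self.rng.randrange(self.buffer_size)
                yield self._buffer[idx]
                self._buffer[idx] = sample
            else:
                self._buffer.append(sample)
        # Source exhausted: drain what is left in random order.
        self.rng.shuffle(self._buffer)
        while self._buffer:
            yield self._buffer.pop()
```

Iterating over BufferedShuffle(range(100), buffer_size=10) yields each number exactly once, and len(pipe._buffer) never exceeds 10. While iterating, the memory held by the buffer grows with buffer_size times the sample size.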
It's difficult to share code for this, as it is the property of a large corporation. Some other notes, expanding on my earlier thoughts:
A mention of shuffling causing a memory increase: https://github.com/pytorch/pytorch/issues/13246#issuecomment-708067670
I generally store data in tar files in the form (arbitrary-length audio, {labels: tensor, dataset: string, ID: string}).
And here is an expansion of my list of pipes:
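For now I have solved my issue by monkeypatching torch.utils.data.graph_settings.apply_random_seed as follows:
import torch
import torch.utils.data.graph_settings
from torch.utils.data.graph import DataPipe
def apply_random_seed_overwrite(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    return datapipe  # no-op: skip the graph traversal and return the datapipe unchanged
torch.utils.data.graph_settings.apply_random_seed = apply_random_seed_overwrite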
This effectively disables apply_random_seed, which for my purposes is not a problem, since in training I am not providing a seed and in validation/testing I am not shuffling. Doing this completely eliminates the memory spikes (since we no longer traverse the datapipe at the beginning of each epoch).
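One general Python caveat when monkeypatching like this (not verified against the torch or torchdata sources): reassigning a module attribute only changes behavior for code that looks the function up through the module at call time. Code that bound the name earlier, for example via from module import apply_random_seed, keeps calling the original. The sketch below uses a hypothetical fake_graph_settings module to show the difference.

```python
import sys
import types

# Hypothetical stand-in for torch.utils.data.graph_settings (not the real module).
settings = types.ModuleType("fake_graph_settings")


def apply_random_seed(datapipe, rng):
    """Original behavior: pretend to traverse and reseed the graph."""
    return ("traversed", datapipe)


settings.apply_random_seed = apply_random_seed
sys.modules["fake_graph_settings"] = settings

# Binds the current function object at import time.
from fake_graph_settings import apply_random_seed as bound_early  # noqa: E402


def call_via_module(dp):
    # Looks the attribute up on the module every time it is called.
    return sys.modules["fake_graph_settings"].apply_random_seed(dp, None)


def call_via_bound_name(dp):
    # Uses the name bound by the from-import above.
    return bound_early(dp, None)


def apply_random_seed_overwrite(datapipe, rng):
    """No-op replacement, mirroring the monkeypatch above."""
    return datapipe


settings.apply_random_seed = apply_random_seed_overwrite

print(call_via_module("dp"))      # patched: the datapipe comes back unchanged
print(call_via_bound_name("dp"))  # unpatched: still runs the original function
```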
@andrew-bydlon Are you saving anything in memory (like audio samples)? That would likely cause the same issue I was having.
I'm not saving anything in memory other than what is prefetched. I'm using iterable datapipes for all of the above, per the recommendations on the homepage. These default to prefetch factors of 10. The augmentation operations take some compute, but all of this happens at the start of epoch 2 (memory usage goes from 20% to 100%), so it is extremely unexpected.
Thank you both for your help. I have finally taken a deep dive into this topic and opened an issue: https://github.com/pytorch/data/issues/1185
There is a lot of discussion of memory leaks in the issues. I really like the DataLoader2 API, but I will temporarily switch back to DL1 because of the problems I mention there.
I tried the apply_random_seed monkeypatch without success. Glad it worked for you!
🐛 Describe the bug
I’ve noticed large “spikes” in memory usage at the start of epochs when using IterDataPipes with attributes that take a lot of memory. These can cause my training jobs to fail with out-of-memory errors.
Here’s a minimal example to reproduce:
The memory usage (logged with psutil) looks like this:
Here, start_epoch indicates the start of an epoch and first_iter corresponds to the first time each epoch that we reach the pass statement in the dataloader loop. (To simplify the example code above, I removed the code that logs start_epoch and first_iter. I logged the memory usage from a separate process.)
After some debugging, I can say that the memory spikes occur during the graph traversal performed by torch/utils/data/graph_settings.py::apply_random_seed() at the beginning of each epoch. Disabling the body of this function removes the memory spikes.
The spikes seem to be caused by the pickling at https://github.com/pytorch/pytorch/blob/99ded8bbcea896b02f1c0babb055329c503ca95e/torch/utils/data/graph.py#L23, where the code defines f = io.BytesIO() and pickles to f. If there are large datapipes to be pickled, it makes sense that memory usage blows up quickly and then falls again when f goes out of scope.
I tried replacing f = io.BytesIO() with f = open(os.devnull, "wb") (and adding f.close() at the end of the function). This didn’t eliminate the memory spikes, but it did make them a bit smaller.
A few notes:
.in_memory_cache() to see these spikes; it seems that any datapipe that occupies a lot of memory will cause them.
Versions
I have tested the above with both
I observed the same behavior in both cases.