pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Extend docs - Fixing out of memory with python garbage collection #95462

Open voegtlel opened 1 year ago

voegtlel commented 1 year ago

📚 The doc issue

I suggest extending https://pytorch.org/docs/stable/notes/faq.html with a section about the Python garbage collector, which is currently missing.

To be specific, we've now encountered several scenarios where we hit the following case:

Other processing during the training loop (in one case the data loader was loading complex data structures, in the other the main loop created complex TensorBoard logs) triggers the garbage collector for gen0 and gen1, which moves a few of our large tensors into gen2 (only about 10-20 objects, but a few hundred megabytes on the GPU). These large tensors are therefore not released until the next gen2 collection. After adding a gc.collect() every nth iteration (or specifically, after each of these costly iterations), the leak was gone. An alternative or addition is to call gc.freeze() before the loop: it moves all objects currently tracked by the collector into a permanent generation that later collections ignore, which empties gen2, thus reducing the number of new objects required to trigger a gen2 collection and also the gen2 collection overhead.
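A minimal sketch of the two workarounds described above, using only the standard gc module. The names run_loop, step_fn and collect_every are hypothetical and chosen for illustration; step_fn stands in for one training or inference iteration. The gc.unfreeze() call at the end is an addition of this sketch, not something the workaround itself requires.

```python
import gc


def run_loop(step_fn, num_iterations, collect_every=100, freeze_first=True):
    """Call step_fn(i) num_iterations times with a periodic full collection.

    step_fn and collect_every are hypothetical names for this sketch.
    Returns how many explicit gc.collect() calls were made.
    """
    if freeze_first:
        # Move everything tracked so far (start-up objects) into the
        # permanent generation, which later collections ignore.
        gc.freeze()

    explicit_collections = 0
    try:
        for i in range(num_iterations):
            step_fn(i)
            # Every collect_every iterations, force a full collection so
            # that objects stuck in generation 2 are released.
            if collect_every and (i + 1) % collect_every == 0:
                gc.collect()
                explicit_collections += 1
    finally:
        if freeze_first:
            # Move frozen objects back into the oldest generation so the
            # process behaves as before once the loop is done.
            gc.unfreeze()
    return explicit_collections


def step(i):
    # Stand-in for real work: creates a reference cycle that only the
    # cyclic garbage collector can reclaim.
    cycle = []
    cycle.append(cycle)


run_loop(step, num_iterations=250, collect_every=100)
```

Calling gc.collect() only after the known costly iterations, instead of on a fixed interval, would follow the same pattern with a different condition.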

Detailed explanation: A gen2 collection only triggers if long_lived_pending / long_lived_total > 0.25, and there are already a few hundred thousand objects in gen2 after startup, so it will not trigger before many of these large tensors have ended up in gen2. An example with numbers: assuming long_lived_total is 350,000 objects (approximately the number I see in the inference loop), 87,500 new objects would have to reach gen2 before a gen2 collection is triggered, which would take ~5k iterations. The process will hit an OOM long before that happens.
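The arithmetic in this example can be reproduced with a short sketch. The function names are hypothetical, and the 0.25 ratio is the value quoted above. The figure of 17.5 objects promoted to gen2 per iteration is an assumption (the midpoint of the 10-20 objects mentioned earlier), used only because it matches the ~5k iterations estimate.

```python
import math


def full_gc_threshold(long_lived_total, ratio=0.25):
    """Approximate number of new gen2 objects needed before a gen2 collection.

    Uses the rule quoted in this issue:
    long_lived_pending / long_lived_total > ratio.
    """
    return int(long_lived_total * ratio)


def iterations_until_full_gc(long_lived_total, promoted_per_iteration, ratio=0.25):
    """Iterations needed to reach the threshold at a steady promotion rate."""
    return math.ceil(long_lived_total * ratio / promoted_per_iteration)


threshold = full_gc_threshold(350_000)                # 87500
iterations = iterations_until_full_gc(350_000, 17.5)  # 5000 (assumed rate)
print(threshold, iterations)
```

With a few hundred megabytes of GPU memory held by the objects promoted along the way, the memory limit is reached long before the threshold is.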

Suggest a potential alternative/fix

The suggested addition:

Python garbage collection: If you create many Python objects during your loop(s), consider calling the Python garbage collector periodically with gc.collect() and/or freezing existing objects with gc.freeze() after initialization and before entering your loop(s) to keep generation 2 small. Read more about Python garbage collection here.

Thanks for considering this addition!

cc @svekars @carljparker

ezyang commented 1 year ago

Sure, send us a PR