pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Running JIT trace for many times leads to OOM #86537

Open Co1lin opened 1 year ago

Co1lin commented 1 year ago

πŸ› Describe the bug

Hi! I found that running JIT trace on many different models in one process can lead to out-of-memory (OOM) errors. I suspect a memory leak: perhaps torch currently doesn't free a compiled graph even after its lifetime has ended.
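To make the suspected failure mode concrete, here is a toy model in plain Python. All class and function names are hypothetical, and this is not PyTorch's actual implementation. Each "trace" creates per-trace compiled state under a fresh unique name and records it in a process-wide registry. If the registry holds strong references, that state outlives the module that produced it, which would match the growth described above. A registry holding weak references lets the state be freed along with the module.

```python
import gc
import itertools
import weakref


class ToyCompiledType:
    """Hypothetical stand-in for per-trace compiled state (graph, type info)."""

    def __init__(self, name: str, payload_size: int) -> None:
        self.name = name
        self.payload = bytearray(payload_size)  # simulated graph memory


class ToyTracedModule:
    """Hypothetical stand-in for the module returned by a trace."""

    def __init__(self, compiled_type: ToyCompiledType) -> None:
        self.compiled_type = compiled_type


_name_counter = itertools.count()


def toy_trace(registry, payload_size: int = 1024) -> ToyTracedModule:
    # Every trace gets a fresh unique name, so registry entries are never reused.
    name = f"toy_type_{next(_name_counter)}"
    compiled_type = ToyCompiledType(name, payload_size)
    registry[name] = compiled_type
    return ToyTracedModule(compiled_type)


def entries_left_after_loop(registry, iterations: int = 100) -> int:
    """Trace and discard `iterations` modules, then count registry entries left."""
    for _ in range(iterations):
        module = toy_trace(registry)
        del module  # the traced module "dies" here, as in the repro loop
    gc.collect()
    return len(registry)


# A plain dict keeps every compiled type alive forever (the suspected leak).
strong_registry: dict = {}
# A WeakValueDictionary drops an entry once nothing else references its value.
weak_registry = weakref.WeakValueDictionary()

print(entries_left_after_loop(strong_registry))
print(entries_left_after_loop(weak_registry))
```

`gc.collect()` cannot reclaim the strong-registry entries: they are still reachable from a live global, so they are not garbage, just never pruned.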

I think this issue is worth fixing because NAS (neural architecture search) workloads may trace many different models in a single process.

Code to reproduce this issue:

import random
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, num_layers: int, input_dim: int) -> None:
        super().__init__()
        layer = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.LeakyReLU(),
        )
        self.layers = nn.Sequential(*[
            layer for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

if __name__ == '__main__':
    import os, psutil
    process = psutil.Process(os.getpid())

    times = 0
    while True:
        num_layers = random.randint(1, 15)
        input_dim = random.randint(1, 20)
        m = MyModule(num_layers, input_dim)
        x = torch.randn((1, input_dim))
        exported = torch.jit.trace(m, x)
        o = exported(x)

        times += 1
        if times % 100 == 0:
            mem_mb = process.memory_info().rss / (1024 ** 2)
            print(f'{times} times, mem usage: {mem_mb} MB')

Logs:

100 times, mem usage: 368.4140625 MB
200 times, mem usage: 512.015625 MB
300 times, mem usage: 648.3984375 MB
400 times, mem usage: 812.625 MB
500 times, mem usage: 959.3203125 MB
600 times, mem usage: 1099.828125 MB
700 times, mem usage: 1243.9453125 MB
800 times, mem usage: 1380.0703125 MB
900 times, mem usage: 1519.7734375 MB
1000 times, mem usage: 1662.6015625 MB
1100 times, mem usage: 1790.734375 MB
1200 times, mem usage: 1940.265625 MB
1300 times, mem usage: 2081.95703125 MB
1400 times, mem usage: 2219.88671875 MB
1500 times, mem usage: 2365.80859375 MB
1600 times, mem usage: 2509.328125 MB
1700 times, mem usage: 2657.62109375 MB
1800 times, mem usage: 2821.58984375 MB
1900 times, mem usage: 2967.51171875 MB
2000 times, mem usage: 3113.43359375 MB
2100 times, mem usage: 3249.04296875 MB
2200 times, mem usage: 3383.87890625 MB
2300 times, mem usage: 3520.26171875 MB
2400 times, mem usage: 3664.63671875 MB
2500 times, mem usage: 3799.21484375 MB
2600 times, mem usage: 3947.01953125 MB
2700 times, mem usage: 4081.85546875 MB
2800 times, mem usage: 4227.51953125 MB
...
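To put a number on the growth, the sketch below fits a least-squares line to a few (iterations, RSS) samples copied from the log above. The helper names are illustrative; only the sample values come from the log.

```python
import re

# A few (iterations, RSS) samples copied verbatim from the log above.
LOG = """\
100 times, mem usage: 368.4140625 MB
1000 times, mem usage: 1662.6015625 MB
2000 times, mem usage: 3113.43359375 MB
2800 times, mem usage: 4227.51953125 MB
"""

PATTERN = re.compile(r"(\d+) times, mem usage: ([\d.]+) MB")


def parse_log(text: str) -> list:
    """Return a list of (iterations, rss_mb) pairs found in the log text."""
    return [(int(n), float(mb)) for n, mb in PATTERN.findall(text)]


def growth_per_iteration(samples: list) -> float:
    """Least-squares slope of RSS (MB) against iteration count."""
    n = len(samples)
    mean_x = sum(x for x, _ in samples) / n
    mean_y = sum(y for _, y in samples) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in samples)
    sxx = sum((x - mean_x) ** 2 for x, _ in samples)
    return sxy / sxx


samples = parse_log(LOG)
rate = growth_per_iteration(samples)
print(f"~{rate:.2f} MB of RSS growth per trace")
```

On these four samples the slope is about 1.43 MB per iteration (each iteration traces one randomly sized model), so a search that traces 10,000 candidate models in one process would be expected to retain on the order of 14 GB.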

If the model is run eagerly without tracing, as below, the issue disappears:

while True:
    num_layers = random.randint(1, 15)
    input_dim = random.randint(1, 20)
    m = MyModule(num_layers, input_dim)
    x = torch.randn((1, input_dim))
    o = m(x)  # run eagerly; no torch.jit.trace

Versions

Collecting environment information...
PyTorch version: 1.12.1+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.13 (main, Aug 25 2022, 23:26:10)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-48-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.3
[pip3] torch==1.12.1+cu116
[pip3] torchaudio==0.12.1+cu116
[pip3] torchvision==0.13.1+cu116
[conda] numpy                     1.23.3                   pypi_0    pypi
[conda] torch                     1.12.1+cu116             pypi_0    pypi
[conda] torchaudio                0.12.1+cu116             pypi_0    pypi
[conda] torchvision               0.13.1+cu116             pypi_0    pypi

sanchitintel commented 1 year ago

Probably a duplicate of #35600.

Co1lin commented 1 year ago

@sanchitintel Thanks for finding the related issue! I rechecked my experiments and found that repeatedly tracing the same model also triggers this issue. I've corrected the statement above.

I think https://github.com/pytorch/pytorch/issues/35600 is a simpler case of mine, because it compiles the same model each time. It's worth noting that even after the simpler case is fixed, this issue with compiling a different model each time will still need fixing: compiling a single model many times is rare, but compiling many different models is common in neural architecture search. BTW, compiling different models means the previously compiled models are "dead" (they have been released by Python), yet something related to them is not deleted.

sanchitintel commented 1 year ago

Thanks for following up, @Co1lin!

BTW, compiling different models means the models compiled before are "dead"

If you mean there's a loop in which a variable is updated in each iteration, as in the example in #35600's description, then I think you're right!

(they are released by Python) but something related to them are not deleted

I believe you're right that references to the traced-model objects are somehow not being released. Fixing this issue would require digging deeper: perhaps Python's garbage collector API could be used to inspect refcounts and find the root cause, so that the issue can then be fixed in the PyTorch code. @voznesenskym might have more insight into this issue.
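Following the suggestion to use Python's garbage-collector API, here is a minimal sketch of such a check: create objects in a loop, keep only weak references to them, drop the strong ones, run `gc.collect()`, and count how many are still alive. `count_survivors` is a hypothetical helper, not a PyTorch API, and it only sees Python objects; memory held purely on the C++ side would not show up.

```python
import gc
import weakref
from typing import Callable


def count_survivors(factory: Callable[[], object], iterations: int = 50) -> int:
    """Create `iterations` objects, keep only weak references to them,
    run the cyclic garbage collector, and return how many are still alive.

    A non-zero result means something (a cache, a registry, or a global)
    still holds a strong reference to those objects.
    """
    weak_refs = []
    for _ in range(iterations):
        obj = factory()
        weak_refs.append(weakref.ref(obj))
        del obj  # drop the loop's own strong reference
    gc.collect()  # also reclaim objects kept alive only by reference cycles
    return sum(1 for ref in weak_refs if ref() is not None)
```

With PyTorch installed, one could pass a factory such as `lambda: torch.jit.trace(MyModule(3, 8), torch.randn(1, 8))`, using the `MyModule` class from the reproduction. A result of 0 while RSS keeps growing would suggest the retained memory is not reachable through the Python-level module objects; a non-zero result would point at a Python-side reference, which `gc.get_referrers` can then help locate.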

ganler commented 1 year ago

Just curious: are there any workarounds or dirty fixes to get around this issue when compiling multiple models in one process, in case it can't be fixed right away? Thanks!

ganler commented 1 year ago

I found that calling torch._C._jit_clear_class_registry() somewhat alleviates this issue, but some memory still leaks.

pfeatherstone commented 10 months ago

Is there a potential fix in the works?