lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

Potential memory leak issue with `memory_efficient=True` #139

Status: Open. Opened by keven425 4 years ago.

keven425 commented 4 years ago

Hi, thank you for implementing and deploying the models. It's awesome.

I ran into this issue while using a downstream library (segmentation_models_pytorch) that depends on this one, and I opened an issue there. Based on my experiments, however, the issue might be related to the memory_efficient=True setting in this library.

I'm pasting the post below. Any thoughts would be appreciated.

========

Currently, if I run an FPN with an efficientnet-b0 encoder and memory_efficient=True, memory appears to leak. This happens only when both of the following conditions hold:

  1. The model is an FPN. If I run EfficientNet on its own, without the decoder, the leak does not occur.
  2. Swish is configured with set_swish(memory_efficient=True). With memory_efficient=False, the leak does not occur.
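For context: as far as I can tell from the library source, set_swish(memory_efficient=True) replaces the plain `x * sigmoid(x)` activation with one backed by a custom `torch.autograd.Function` that saves only its input and recomputes the sigmoid during backward. The sketch below illustrates that technique under this assumption. The names `SwishFunction`, `MemoryEfficientSwishSketch`, and `PlainSwish` are hypothetical, and this is not the library's actual code.

```python
import torch
import torch.nn as nn


class SwishFunction(torch.autograd.Function):
    """Hypothetical sketch (not efficientnet_pytorch's code): swish that
    saves only its input for backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # only x is kept alive until backward
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        sig = torch.sigmoid(x)  # recomputed instead of stored
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sig * (1 + x * (1 - sig)))


class MemoryEfficientSwishSketch(nn.Module):
    def forward(self, x):
        return SwishFunction.apply(x)


class PlainSwish(nn.Module):
    # Ordinary autograd: keeps both x and sigmoid(x) for the backward pass.
    def forward(self, x):
        return x * torch.sigmoid(x)
```

The trade-off is one extra sigmoid evaluation in backward in exchange for one fewer saved activation tensor per Swish call, since plain autograd keeps both `x` and `sigmoid(x)` alive for the gradient of the multiplication.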

Here is my use case and how I found the bug: I have a function that tries a large batch size, then progressively decrements it until a step runs without a CUDA OOM error. With other models this function works fine. With FPN efficientnet-b0, however, it fails to find a valid batch size: the batch size decreases all the way to 1 and still runs out of memory.

Here is a repro script:

import torch
import re

import torch.nn as nn
import segmentation_models_pytorch as smp

from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import get_model_params

device = torch.device("cuda:0")
args = get_model_params('efficientnet-b0', override_params=None)
criterion = nn.CrossEntropyLoss()

def run_lower_batch_size(model, label_dim=None):
    # Try batch sizes from 24 down to 1 and stop at the first one that
    # completes a forward + backward pass without a CUDA OOM error.
    for batch_size in range(24, 0, -1):
        try:
            print('batch size: ', batch_size, end='\t')
            inputs = torch.ones([batch_size, 3, 640, 512], device=device)
            label_shape = [batch_size]
            if label_dim:
                label_shape += label_dim
            # CrossEntropyLoss expects int64 class indices
            label = torch.ones(label_shape, dtype=torch.long, device=device)
            output = model(inputs)
            loss = criterion(output, label)
            loss.backward()
            print('succeeded')
            return
        except RuntimeError as e:
            # Swallow only CUDA OOM errors; re-raise anything else unchanged
            if re.search(r'CUDA.*out of memory', str(e)) is not None:
                print('failed w/ cuda oom error')
                continue
            raise

    raise Exception('did not find valid batch size')

print('\n\nrun encoder only with memory_efficient=True, will NOT fail with cuda oom error\n')
model = EfficientNet(*args)
model.set_swish(memory_efficient=True)
model.to(device)
run_lower_batch_size(model)

print('\n\nrun encoder-decoder with memory_efficient=False, will NOT fail with cuda oom error\n')
model = smp.FPN('efficientnet-b0', classes=2)
model.encoder.set_swish(memory_efficient=False)
model.to(device)
run_lower_batch_size(model, label_dim=[640, 512])

print('\n\nrun encoder-decoder with memory_efficient=True, will ALWAYS fail with cuda oom error\n')
model = smp.FPN('efficientnet-b0', classes=2)
model.encoder.set_swish(memory_efficient=True)
model.to(device)
run_lower_batch_size(model, label_dim=[640, 512])

Output:

/home/ubuntu/.local/share/virtualenvs/factory-H3jTBYbX/bin/python /home/ubuntu/Desktop/pycharm-community-2019.1.1/helpers/pydev/pydevd.py --multiproc --qt-support=auto --client 127.0.0.1 --port 44369 --file /home/ubuntu/factory/learn/localize/scripts/swish_mem_leak.py
Connected to pydev debugger (build 191.7141.48)
pydev debugger: process 16718 is connecting

run encoder only with memory_efficient=True, will NOT fail with cuda oom error

batch size:  24 failed w/ cuda oom error
batch size:  23 failed w/ cuda oom error
batch size:  22 failed w/ cuda oom error
batch size:  21 failed w/ cuda oom error
batch size:  20 failed w/ cuda oom error
batch size:  19 failed w/ cuda oom error
batch size:  18 succeeded

run encoder-decoder with memory_efficient=False, will NOT fail with cuda oom error

batch size:  24 failed w/ cuda oom error
batch size:  23 failed w/ cuda oom error
batch size:  22 failed w/ cuda oom error
batch size:  21 failed w/ cuda oom error
batch size:  20 failed w/ cuda oom error
batch size:  19 failed w/ cuda oom error
batch size:  18 failed w/ cuda oom error
batch size:  17 failed w/ cuda oom error
batch size:  16 failed w/ cuda oom error
batch size:  15 failed w/ cuda oom error
batch size:  14 failed w/ cuda oom error
batch size:  13 failed w/ cuda oom error
batch size:  12 succeeded

run encoder-decoder with memory_efficient=True, will ALWAYS fail with cuda oom error

batch size:  24 failed w/ cuda oom error
batch size:  23 failed w/ cuda oom error
batch size:  22 failed w/ cuda oom error
batch size:  21 failed w/ cuda oom error
batch size:  20 failed w/ cuda oom error
batch size:  19 failed w/ cuda oom error
batch size:  18 failed w/ cuda oom error
batch size:  17 failed w/ cuda oom error
batch size:  16 failed w/ cuda oom error
batch size:  15 failed w/ cuda oom error
batch size:  14 failed w/ cuda oom error
batch size:  13 failed w/ cuda oom error
batch size:  12 failed w/ cuda oom error
batch size:  11 failed w/ cuda oom error
batch size:  10 failed w/ cuda oom error
batch size:  9  failed w/ cuda oom error
batch size:  8  failed w/ cuda oom error
batch size:  7  failed w/ cuda oom error
batch size:  6  failed w/ cuda oom error
batch size:  5  failed w/ cuda oom error
batch size:  4  failed w/ cuda oom error
batch size:  3  failed w/ cuda oom error
batch size:  2  failed w/ cuda oom error
batch size:  1  failed w/ cuda oom error
Traceback (most recent call last):
  File "/home/ubuntu/Desktop/pycharm-community-2019.1.1/helpers/pydev/pydevd.py", line 1758, in <module>
    main()
  File "/home/ubuntu/Desktop/pycharm-community-2019.1.1/helpers/pydev/pydevd.py", line 1752, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/home/ubuntu/Desktop/pycharm-community-2019.1.1/helpers/pydev/pydevd.py", line 1147, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/ubuntu/Desktop/pycharm-community-2019.1.1/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/ubuntu/factory/learn/localize/scripts/swish_mem_leak.py", line 57, in <module>
    run_lower_batch_size(model, label_dim=[640, 512])
  File "/home/ubuntu/factory/learn/localize/scripts/swish_mem_leak.py", line 37, in run_lower_batch_size
    raise Exception('did not find valid batch size')
Exception: did not find valid batch size

Process finished with exit code 1
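A possible variant of the batch-size search in the repro, which removes one confounder, is sketched below. In the repro, if the forward pass succeeds but `backward()` runs out of memory, `output` and `loss` (and the autograd graph behind them) stay referenced by the loop's local variables until they are overwritten during a later attempt. Separately, the PyTorch FAQ notes that a caught OOM exception holds a reference to the stack frame where it was raised, and recommends doing recovery outside the `except` clause. The sketch assumes a hypothetical `try_step(batch_size)` callable that runs one forward/backward pass with all of its tensors as locals, so they become unreachable once it raises; `find_batch_size` and `on_oom` are hypothetical names too. This does not diagnose the suspected Swish issue; it only makes the harness itself less likely to hold on to memory.

```python
import gc
import re
from typing import Callable, Optional

# Same pattern the repro uses to recognise CUDA OOM errors.
_OOM_PATTERN = re.compile(r"CUDA.*out of memory")


def find_batch_size(
    try_step: Callable[[int], None],
    start: int = 24,
    on_oom: Optional[Callable[[], None]] = None,
) -> int:
    """Return the largest batch size in [1, start] for which try_step succeeds.

    Hypothetical helper: try_step(batch_size) is assumed to run one
    forward/backward pass and keep all of its tensors as locals.
    """
    for batch_size in range(start, 0, -1):
        try:
            try_step(batch_size)
            return batch_size
        except RuntimeError as e:
            if _OOM_PATTERN.search(str(e)) is None:
                raise  # not an OOM: propagate unchanged
        # Recovery runs here, outside the except clause, so the caught
        # exception and its traceback have already been released.
        gc.collect()
        if on_oom is not None:
            on_oom()  # e.g. torch.cuda.empty_cache in a real run
    raise RuntimeError("did not find a valid batch size")
```

To use it with the repro, one could move the body of the `try` block into a function of `batch_size` and pass `on_oom=torch.cuda.empty_cache`. Printing `torch.cuda.memory_allocated()` after each failed attempt would then show whether allocated memory keeps growing from one attempt to the next, which would help distinguish a genuine leak from references held by the test harness.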
keven425 commented 4 years ago

@lukemelas any thoughts?