facebookresearch / maskrcnn-benchmark

Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch.

Increasing memory consumption when training Retina Net #884

Open buaaMars opened 5 years ago

buaaMars commented 5 years ago

❓ Questions and Help

@fmassa @chengyangfu Hi, Thanks for reading my issue!

When I am training RetinaNet, the memory consumption keeps increasing until OOM. I have read several issues related to OOM. The causes of OOM suggested in those issues can be summarized as follows: 1) Different aspect ratios make PyTorch reallocate larger buffers; in this case, memory usage can increase before the end of the first epoch. 2) The number of proposals affects memory usage. 3) A large number of gt bboxes consumes a lot of memory; in this case, the increase in memory usage also happens before the end of the first epoch.

With the solutions you provided and some help from the Internet, I made some "improvements" with regard to these potential causes: 1) For every image, I randomly crop a patch and resize it. The width and height of the patch and the resizing ratio are the same for all images, so all inputs have the same width and height, except for one image in the dataset whose original size is smaller than a patch. 2) Since the RetinaNet I am running is a one-stage detector, there are no proposals at all. 3) I use torch.jit.script according to

https://github.com/facebookresearch/maskrcnn-benchmark/issues/18#issuecomment-466483262

I replace

https://github.com/facebookresearch/maskrcnn-benchmark/blob/95521b646c486faaf6c20af6a3ef08c41fa8f67b/maskrcnn_benchmark/structures/boxlist_ops.py#L53

with

import math
import torch

@torch.jit.script
def boxes_iou(box1: torch.Tensor, box2: torch.Tensor):
    # box1: [N, 4], box2: [M, 4], both in (x1, y1, x2, y2) format
    b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
    b1y1 = box1[:, 1].unsqueeze(1)
    b1x2 = box1[:, 2].unsqueeze(1)
    b1y2 = box1[:, 3].unsqueeze(1)
    b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
    b2y1 = box2[:, 1].unsqueeze(0)
    b2x2 = box2[:, 2].unsqueeze(0)
    b2y2 = box2[:, 3].unsqueeze(0)
    ltx = torch.max(b1x1, b2x1)  # [N,M]
    lty = torch.max(b1y1, b2y1)
    rbx = torch.min(b1x2, b2x2)
    rby = torch.min(b1y2, b2y2)
    TO_REMOVE = 1
    w = (rbx - ltx + TO_REMOVE).clamp(min=0, max=math.inf)  # [N,M]
    h = (rby - lty + TO_REMOVE).clamp(min=0, max=math.inf)  # [N,M]
    inter = w * h  # [N,M]
    area1 = (b1x2 - b1x1 + TO_REMOVE) * (b1y2 - b1y1 + TO_REMOVE)  # [N,1]
    area2 = (b2x2 - b2x1 + TO_REMOVE) * (b2y2 - b2y1 + TO_REMOVE)  # [1,M]
    iou = inter / (area1 + area2 - inter)
    return iou

def boxlist_iou(boxlist1, boxlist2):
    if boxlist1.size != boxlist2.size:
        raise RuntimeError(
                "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
    iou = boxes_iou(boxlist1.bbox, boxlist2.bbox)
    return iou
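As a quick sanity check of the scripted version (a minimal sketch, not part of the original patch; it assumes the usual BoxList constructor from maskrcnn_benchmark.structures.bounding_box), the IoU of a box with itself should be exactly 1:

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList

# two small BoxLists on the same (hypothetical) 800x800 image,
# where the second list is the first 20 boxes of the first
xy1 = torch.rand(50, 2) * 400
wh = torch.rand(50, 2) * 300 + 1
boxes = torch.cat([xy1, xy1 + wh], dim=1)
b1 = BoxList(boxes, (800, 800), mode="xyxy")
b2 = BoxList(boxes[:20], (800, 800), mode="xyxy")

iou = boxlist_iou(b1, b2)  # the scripted implementation above
assert iou.shape == (50, 20)
# each of the first 20 boxes overlaps itself perfectly
assert torch.allclose(iou[:20].diagonal(), torch.ones(20))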

4) I call torch.cuda.empty_cache() after an OOM happens. Also, when an input causes OOM right after I run torch.cuda.empty_cache(), I skip it, assuming it has a large number of gt bboxes. I replace

https://github.com/facebookresearch/maskrcnn-benchmark/blob/d2698472fc7b07ebf8580604bc86037855e68d45/maskrcnn_benchmark/engine/trainer.py#L71

with

        try:
            loss_dict = model(images, targets)
        except RuntimeError as e1:
            if 'out of memory' in str(e1):
                print('trainer.py(71): out of memory 11111')
                torch.cuda.empty_cache()
                print('torch.cuda.empty_cache(),then try again')
                try:
                    loss_dict = model(images, targets)
                except RuntimeError as e2:
                    if 'out of memory' in str(e2):
                        print('trainer.py(71): out of memory 22222')
                        print('len(targets)',len(targets))
                        print('skip this iteration')
                        for target in targets:
                            print(target)
                        continue
                    else:
                        raise e2
            else:
                raise e1
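A note on this guard: torch.cuda.empty_cache() only returns cached blocks that no live tensor is still using, so how much memory the retry actually gets back depends on how much of the failed forward pass has been released by that point. The continue also assumes this block sits inside the training loop in trainer.py, so the offending batch is skipped without a backward pass or optimizer step.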

With these "improvement", memory consumption still increases in training. At first, the memory consumption reported by pytorch is 4 or 5G. Only inputs whose gt bboxes more than 800 cause OOM, and they are skipped after OOM. Gradually, the memory occupation increase. There always is a 2.7G leap after hundreds of iterations. Then, inputs who has 260+ gt bboxes can cause OOM. Thousands of iterations later, memory occupation showed by nvidia-smi approach the max memory of my GPU, that is 12 G. The minimum of the numbers of gt bboxes of the images causing OOM can be 100+. At last, OOM happens at every iteration, even with the images who has only 1 gt bboxes. The training cannot continue any more. I have to kill the program and restart it at the last checkpoint. At the first iterations after restart, the memory consumption is 4 or 5G, as little as the one at the start of training. Instead of rapidly increasing to the large memory occupation when I killed it, the memory occupation gradually increases just like I start the training from 0 iteration. It makes me feel like the training doesn't need that much memory at all.

Given the behavior described above, I believe the increasing memory usage cannot simply be explained by attributes of certain inputs, because the number of "problem images" keeps growing as the epochs go by; that is, an input that is fine in earlier iterations causes OOM later. It looks more like a memory leak. The "improvements" I have made let me run the program with larger inputs for a longer time before it stops working, but they do not solve the fundamental problem of increasing memory usage. I have to restart the program every 10k iterations.
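To tell the allocator's caching apart from genuinely growing allocations, one option (a minimal sketch, not something wired into the repo) is to log PyTorch's own memory counters every few hundred iterations next to what nvidia-smi reports:

import torch

def log_gpu_memory(iteration, device=0):
    # memory actually occupied by live tensors
    allocated = torch.cuda.memory_allocated(device) / 1024 ** 3
    # memory held by PyTorch's caching allocator (roughly what nvidia-smi shows);
    # on older PyTorch versions this counter is called torch.cuda.memory_cached()
    reserved = torch.cuda.memory_reserved(device) / 1024 ** 3
    peak = torch.cuda.max_memory_allocated(device) / 1024 ** 3
    print("iter {}: allocated {:.2f} GB, reserved {:.2f} GB, peak {:.2f} GB".format(
        iteration, allocated, reserved, peak))

# e.g. inside the training loop in trainer.py:
# if iteration % 200 == 0:
#     log_gpu_memory(iteration)

If memory_allocated keeps climbing while the input sizes stay fixed, some tensors are being kept alive across iterations rather than the cache simply growing.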

With this issue, I respectfully ask that you: 1) check my "improvements" and help me solve the problem; 2) write a document that helps people use memory efficiently; 3) pay attention to the increasing memory usage problem and try to fix it, since this is not the first time it has been reported.

Thank you very much!

heiyuxiaokai commented 5 years ago

@fmassa I met the same problem with RetinaNet, but FCOS works well. @buaaMars I tried cpython to solve it, but it was too slow. It seems the GPU memory gets locked? When it causes OOM, gpu0 or gpu1 stays occupied, and I have to kill the process.

heiyuxiaokai commented 5 years ago

I solved it by adding torch.cuda.empty_cache() after https://github.com/facebookresearch/maskrcnn-benchmark/blob/55796a04ea770029a80cf5933cc5c3f3f6fa59cf/maskrcnn_benchmark/engine/trainer.py#L77-L85. When an image has >= 500 gt boxes, it causes OOM (it needs 4.7 GB of memory to calculate iou_rotate), so I set the maximum number of gt boxes to 300.
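For reference, a minimal sketch of such a cap (the threshold of 300 comes from the comment above; the helper name is made up, and BoxList indexing is assumed to behave like tensor indexing), applied to each target before the forward pass:

import torch

MAX_GT_BOXES = 300  # threshold from the comment above; tune for your GPU

def cap_gt_boxes(targets, max_boxes=MAX_GT_BOXES):
    """Randomly keep at most `max_boxes` gt boxes per image (BoxList targets)."""
    capped = []
    for target in targets:
        n = len(target)
        if n > max_boxes:
            keep = torch.randperm(n, device=target.bbox.device)[:max_boxes]
            target = target[keep]  # BoxList supports index-tensor indexing
        capped.append(target)
    return capped

# e.g. in the training loop, before the forward pass:
# targets = cap_gt_boxes(targets)
# loss_dict = model(images, targets)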

clw5180 commented 5 years ago

> I solved it by adding torch.cuda.empty_cache() after https://github.com/facebookresearch/maskrcnn-benchmark/blob/55796a04ea770029a80cf5933cc5c3f3f6fa59cf/maskrcnn_benchmark/engine/trainer.py#L77-L85. When an image has >= 500 gt boxes, it causes OOM (it needs 4.7 GB of memory to calculate iou_rotate), so I set the maximum number of gt boxes to 300.

Thanks a lot! How did you find this method?

clw5180 commented 5 years ago

> I solved it by adding torch.cuda.empty_cache() after https://github.com/facebookresearch/maskrcnn-benchmark/blob/55796a04ea770029a80cf5933cc5c3f3f6fa59cf/maskrcnn_benchmark/engine/trainer.py#L77-L85. When an image has >= 500 gt boxes, it causes OOM (it needs 4.7 GB of memory to calculate iou_rotate), so I set the maximum number of gt boxes to 300.

But after adding this, it only uses about 2–4 GB of GPU memory, whereas before it used about 10+ GB. Is that normal? By the way, I use ResNeXt-101 with FPN.