diux-dev / cluster

train on AWS

reducing OOM on large batch sizes #50

Closed yaroslavvb closed 6 years ago

yaroslavvb commented 6 years ago

To-dos to fight OOM errors on larger batch sizes:

  1. Option to turn off validation until the last couple of epochs. This can be used for a faster final "blind" run, and validation may also contribute to OOM errors.

    --skip-validation-until-epoch=0 (validate every epoch)
    --skip-validation-until-epoch=38 (starting after epoch 38)
  2. Option to clear the GPU cache each time the network input size changes: https://pytorch.org/docs/stable/cuda.html?highlight=cache#torch.cuda.empty_cache

    --empty-cache-on-set-data

    To be made default if no regression

  3. Add monitoring of GPU usage from the PyTorch memory stats (see the sketch after this list): https://pytorch.org/docs/stable/cuda.html?highlight=cache#memory-management

    memory/memory_allocated
    memory/max_memory_allocated
    memory/memory_cached
    memory/max_memory_cached

    cc @bearpelican
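
A minimal sketch of how items 2 and 3 could fit together (assuming a hypothetical set_data hook on the data manager and a TensorBoard-style event_writer; anything not named in the items above is illustrative, not our actual code):

    import torch

    def log_gpu_memory(tag, event_writer=None, step=0):
        # Allocator stats in bytes; names mirror the proposed memory/* metrics.
        stats = {
            'memory/memory_allocated': torch.cuda.memory_allocated(),
            'memory/max_memory_allocated': torch.cuda.max_memory_allocated(),
            'memory/memory_cached': torch.cuda.memory_cached(),
            'memory/max_memory_cached': torch.cuda.max_memory_cached(),
        }
        for name, value in stats.items():
            print('{} {}: {:,}'.format(tag, name, value))
            if event_writer is not None:
                event_writer.add_scalar(name, value, step)

    def set_data(dm, new_image_size, empty_cache_on_set_data=True):
        # Hypothetical hook called whenever the network input size changes.
        dm.set_size(new_image_size)        # assumed data-manager API
        if empty_cache_on_set_data:        # proposed --empty-cache-on-set-data flag
            torch.cuda.empty_cache()       # return cached blocks to the CUDA driver
        log_gpu_memory('after set_data')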

bearpelican commented 6 years ago

We currently clear the cache (torch.cuda.empty_cache) whenever we change data loaders: https://github.com/diux-dev/cluster/blob/28ad21f0a81d5d8268a16831c284bc57dbcff098/pytorch/training/train_imagenet_nv.py#L154

I know you can get memory leaks if you hold on to a tensor without converting it (https://github.com/diux-dev/cluster/blob/master/pytorch/training/train_imagenet_nv.py#L370). Perhaps we are doing that somewhere.
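
For reference, the classic version of that leak is accumulating a loss tensor instead of a Python number, which keeps each iteration's autograd graph alive. A sketch of the pattern (illustrative only, not our training loop):

    running_loss = 0.0
    for input, target in loader:
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Leaky: `running_loss += loss` would keep the loss tensor (and the
        # graph behind it) alive across iterations. .item() frees it.
        running_loss += loss.item()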

yaroslavvb commented 6 years ago

I seem to run out of memory in the backward pass. Also, not sure if it's a coincidence, but both times this happened on the last worker (worker 15 in the 16-machine run and worker 3 in the 4-machine run).

Epoch: [17][310/313]    Time 0.334 (0.428)      Data 0.000 (0.026)      Loss 6.9066 (6.9062)    Prec@1 0.049 (0.101)    Prec@5 0.391 (0.507)    bw 2.118 2.118
Traceback (most recent call last):
  File "train_imagenet_nv.py", line 645, in <module>
    main()
  File "train_imagenet_nv.py", line 348, in main
    train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
  File "train_imagenet_nv.py", line 435, in train
    loss.backward()
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: out of memory
yaroslavvb commented 6 years ago

I've been hitting OOMs on the 8-machine run. It happens right after switching to image size 288. Looking at memory usage, it's about 12GB right before the switch, so perhaps that's cutting it too close.

[screenshot 2018-08-09 15:46:21: memory usage before the switch]
Image size: 288
Batch size: 128
Train Directory: /home/ubuntu/data/imagenet/train
Validation Directory: /home/ubuntu/data/imagenet/validation
Changing LR from 0.028199999999999996 to 0.018799999999999997
Traceback (most recent call last):
  File "train_imagenet_nv.py", line 662, in <module>
    main()
  File "train_imagenet_nv.py", line 354, in main
    train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
  File "train_imagenet_nv.py", line 445, in train
    loss.backward()
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: out of memory
yaroslavvb commented 6 years ago

So I wrapped the model forward pass as follows:

        print("before fwd {:,}, {:,}".format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))
        output = model(input)
        print("after fwd {:,}, {:,}".format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))

It seems that every time a new loader is created, the "cached" memory grows by about 200MB until it reaches ~15GB and OOMs. gc and clearing the cache don't help:

after fwd 11,444,038,656, 11,612,061,696
before fwd 574,218,240, 1,925,054,464
after fwd 11,705,068,544, 11,957,436,416
before fwd 575,266,816, 13,838,843,904
after fwd 11,706,248,192, 13,838,843,904
before fwd 576,315,392, 13,839,892,480
after fwd 11,708,083,200, 13,839,892,480
before fwd 577,101,824, 13,840,941,056
after fwd 11,708,083,200, 13,840,941,056
Epoch: [0][5/5005]  Time 0.297 (3.011)  Data 0.001 (0.536)  Loss 8.5078 (7.6859)    Prec@1 0.391 (0.078)    Prec@5 0.391 (0.625)    bw 0.000 0.000
before fwd 575,803,392, 13,840,941,056
after fwd 11,707,296,768, 13,840,941,056
before fwd 574,873,600, 13,840,941,056
after fwd 11,705,854,976, 13,840,941,056
before fwd 577,363,968, 13,840,941,056
after fwd 11,708,345,344, 13,840,941,056
before fwd 576,053,248, 13,840,941,056
after fwd 11,707,034,624, 13,840,941,056
before fwd 577,363,968, 13,840,941,056
after fwd 11,709,131,776, 13,840,941,056
Epoch: [0][10/5005] Time 0.298 (1.670)  Data 0.001 (0.283)  Loss 7.7695 (7.9926)    Prec@1 0.000 (0.039)    Prec@5 0.000 (0.391)    bw 0.000 0.000
before fwd 576,196,608, 13,840,941,056
after fwd 11,707,689,984, 13,840,941,056
before prefetcher 421,395,968, 13,840,941,056
after prefetcher 498,468,352, 13,957,595,136
Changing LR from 1.0004395604395604 to 1.20003996003996
before fwd 575,540,736, 14,034,665,472
after fwd 11,707,034,112, 14,034,665,472
before fwd 577,495,040, 14,111,735,808
after fwd 11,708,476,416, 14,111,735,808
before fwd 575,004,672, 14,111,735,808
after fwd 11,705,986,048, 14,111,735,808
before fwd 576,839,680, 14,111,735,808
after fwd 11,707,821,056, 14,111,735,808
before fwd 576,184,320, 14,111,735,808
after fwd 11,707,165,696, 14,111,735,808
Epoch: [1][5/5005]  Time 0.304 (0.871)  Data 0.001 (0.572)  Loss 6.9297 (6.9633)    Prec@1 0.391 (0.156)    Prec@5 0.391 (0.703)    bw 0.000 0.000
before fwd 576,196,608, 14,111,735,808
after fwd 11,707,689,984, 14,111,735,808
before fwd 575,266,816, 14,111,735,808
after fwd 11,706,248,192, 14,111,735,808
before fwd 576,315,392, 14,111,735,808
after fwd 11,707,296,768, 14,111,735,808
before fwd 574,480,384, 14,111,735,808
after fwd 11,705,461,760, 14,111,735,808
before fwd 576,446,464, 14,111,735,808
after fwd 11,707,427,840, 14,111,735,808
Epoch: [1][10/5005] Time 0.309 (0.603)  Data 0.001 (0.304)  Loss 6.9375 (6.9641)    Prec@1 0.000 (0.156)    Prec@5 0.391 (0.664)    bw 0.000 0.000
before fwd 576,720,896, 14,111,735,808
after fwd 11,708,869,632, 14,111,735,808
before prefetcher 422,182,400, 14,111,735,808
after prefetcher 499,254,784, 14,228,389,888
Changing LR from 1.2004395604395603 to 1.40003996003996
before fwd 576,327,168, 14,305,460,224
after fwd 11,708,475,904, 14,305,460,224
before fwd 576,708,608, 14,382,530,560
after fwd 11,707,689,984, 14,382,530,560
before fwd 574,480,384, 14,382,530,560
after fwd 11,705,461,760, 14,382,530,560
before fwd 575,135,744, 14,382,530,560
after fwd 11,706,117,120, 14,382,530,560
before fwd 577,232,896, 14,382,530,560
after fwd 11,708,214,272, 14,382,530,560
Epoch: [2][5/5005]  Time 0.300 (0.812)  Data 0.002 (0.511)  Loss 6.9102 (6.9172)    Prec@1 0.000 (0.078)    Prec@5 0.391 (0.859)    bw 0.000 0.000
before fwd 574,754,816, 14,382,530,560
after fwd 11,706,248,192, 14,382,530,560
before fwd 575,004,672, 14,382,530,560
after fwd 11,705,986,048, 14,382,530,560
before fwd 574,873,600, 14,382,530,560
after fwd 11,705,854,976, 14,382,530,560
before fwd 574,480,384, 14,382,530,560
after fwd 11,705,461,760, 14,382,530,560
before fwd 575,528,960, 14,382,530,560
after fwd 11,706,510,336, 14,382,530,560
Epoch: [2][10/5005] Time 0.297 (0.573)  Data 0.001 (0.274)  Loss 6.9219 (6.9152)    Prec@1 0.391 (0.117)    Prec@5 0.781 (0.703)    bw 0.000 0.000
before fwd 574,623,744, 14,382,530,560
after fwd 11,706,772,480, 14,382,530,560
before prefetcher 420,216,320, 14,382,530,560
after prefetcher 497,288,704, 14,499,184,640
Changing LR from 1.4004395604395605 to 1.6000399600399602
before fwd 574,361,088, 14,576,254,976
after fwd 11,705,854,464, 14,576,254,976
before fwd 574,218,240, 14,653,325,312
after fwd 11,705,199,616, 14,653,325,312
before fwd 577,232,896, 14,653,325,312
after fwd 11,708,214,272, 14,653,325,312
before fwd 575,135,744, 14,653,325,312
after fwd 11,706,117,120, 14,653,325,312
before fwd 576,053,248, 14,653,325,312
after fwd 11,707,034,624, 14,653,325,312
Epoch: [3][5/5005]  Time 0.297 (0.813)  Data 0.001 (0.516)  Loss 6.9102 (6.9156)    Prec@1 0.000 (0.156)    Prec@5 0.781 (0.859)    bw 0.000 0.000
before fwd 573,837,312, 14,653,325,312
after fwd 11,705,461,760, 14,653,325,312
before fwd 574,480,384, 14,653,325,312
after fwd 11,705,461,760, 14,653,325,312
before fwd 576,315,392, 14,653,325,312
after fwd 11,707,296,768, 14,653,325,312
before fwd 574,218,240, 14,653,325,312
after fwd 11,705,199,616, 14,653,325,312
before fwd 575,397,888, 14,653,325,312
after fwd 11,706,379,264, 14,653,325,312
Epoch: [3][10/5005] Time 0.297 (0.573)  Data 0.001 (0.276)  Loss 6.9023 (6.9133)    Prec@1 0.391 (0.117)    Prec@5 0.781 (0.742)    bw 0.000 0.000
before fwd 574,099,456, 14,653,325,312
after fwd 11,705,592,832, 14,653,325,312
before prefetcher 420,740,608, 14,653,325,312
after prefetcher 497,812,992, 847,904,768
after prefetcher 497,812,992, 847,904,768
Changing LR from 1.6004395604395605 to 1.80003996003996
before fwd 574,885,376, 924,975,104
Traceback (most recent call last):
  File "train_imagenet_nv.py", line 690, in <module>
    main()
  File "train_imagenet_nv.py", line 361, in main
    train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
  File "train_imagenet_nv.py", line 439, in train
    output = model(input)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/resnet.py", line 208, in forward
    x = self.layer4(x)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/resnet.py", line 150, in forward
    residual = self.downsample(x)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 308, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: CUDA error: out of memory
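
One way to dig further (a debugging sketch, not part of the training script): enumerate the CUDA tensors that gc can still reach after each loader switch, to see whether something is holding references or whether the growth is purely allocator caching.

    import gc
    import torch

    def live_cuda_tensors():
        # Yield every CUDA tensor the garbage collector can still see.
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) and obj.is_cuda:
                    yield obj
            except Exception:
                pass

    def report_cuda_tensors(tag):
        total = sum(t.element_size() * t.nelement() for t in live_cuda_tensors())
        print('{}: {:,} bytes in reachable CUDA tensors, {:,} cached by allocator'.format(
            tag, total, torch.cuda.memory_cached()))
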
yaroslavvb commented 6 years ago

to reproduce OOM in about 30 seconds:

export zone=....
python launch_nv.py --spot --name somejob --params quick_oom
yaroslavvb commented 6 years ago

Removed a bunch of things (https://github.com/diux-dev/cluster/commit/a8cb1b1f66ae1d39236aa7d1d51d3aa9aa6acafc); still OOM on the first step of the 3rd epoch, both in pytorch_source and DLAMI, and also in the non-distributed version.

checkout oom
python launch_nv.py --spot --name refactor-check --params quick_oom
bearpelican commented 6 years ago

Latest PyTorch is able to get up to a batch size of 224 for 224x224 images. Completely removed the prefetcher as you suggested.

Closing this for now. We can reopen if we see more issues.