pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Faster R-CNN training OOM at ops/boxes.py _box_inter_union function #7959

Open crazyboy9103 opened 1 year ago

crazyboy9103 commented 1 year ago

🐛 Describe the bug

Training Faster R-CNN on a large dataset (~1M images at 512x512 resolution) fails with a CUDA OOM error in the RPN. These are the hyperparameters for the experiment:

rpn_pre_nms_top_n_train = 2000
rpn_pre_nms_top_n_test = 2000
rpn_post_nms_top_n_train = 2000
rpn_post_nms_top_n_test = 2000
rpn_nms_thresh = 0.7
rpn_fg_iou_thresh = 0.7
rpn_bg_iou_thresh = 0.3
rpn_batch_size_per_image = 256
rpn_positive_fraction = 0.5
rpn_score_thresh = 0
box_score_thresh = 0.05
box_nms_thresh = 0.1
box_detections_per_img = 100
box_fg_iou_thresh = 0.5
box_bg_iou_thresh = 0.5
box_batch_size_per_image = 512
box_positive_fraction = 0.25
batch_size = 4

Training then fails with the following traceback:

  File "/workspace/entrypoint.py", line 182, in <module>
    main(args)
  File "/workspace/entrypoint.py", line 152, in main
    trainer.fit(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 219, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 188, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 266, in _optimizer_step
    call._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 146, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py", line 1270, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py", line 161, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 231, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 116, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
    return wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/adam.py", line 121, in step
    loss = closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 103, in _wrap_closure
    closure_result = closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 142, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 128, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 294, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 380, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/workspace/lightning_modules.py", line 96, in training_step
    loss_dict = self(images, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lightning_modules.py", line 90, in forward
    return self.model(images, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/workspace/models/detection/builder.py", line 183, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/models/detection/rpn.py", line 334, in forward
    losses = self._return_loss(anchors, targets, objectness, pred_bbox_deltas)
  File "/workspace/models/detection/rpn.py", line 343, in _return_loss
    labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
  File "/workspace/models/detection/rpn.py", line 193, in assign_targets_to_anchors
    match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
  File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 271, in box_iou
    inter, union = _box_inter_union(boxes1, boxes2)
  File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 250, in _box_inter_union
    union = area1[:, None] + area2 - inter
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.42 GiB (GPU 0; 15.99 GiB total capacity; 12.84 GiB already allocated; 273.61 MiB free; 14.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
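For a sense of scale: the failing line, union = area1[:, None] + area2 - inter, produces one [N, M] tensor, where N is the number of ground-truth boxes in an image and M the number of anchors. The arithmetic below is a rough sketch under stated assumptions: boxes are float32 (4 bytes per element), no autograd graph is kept for anchor matching, and freed temporaries are immediately reusable. The helper names are hypothetical.

```python
GIB = 2 ** 30


def pairs_in_failed_alloc(alloc_gib: float, elem_bytes: int = 4) -> float:
    """Number of (gt box, anchor) pairs in a single [N, M] tensor of alloc_gib GiB."""
    return alloc_gib * GIB / elem_bytes


def vectorized_peak_bytes(n: int, m: int, elem_bytes: int = 4) -> int:
    """Rough peak of the original _box_inter_union at the `union = ...` line.

    Assumed live tensors at that point:
      lt, rb, wh                      -> 3 x [N, M, 2] = 6 * N * M elements
      inter, area1[:, None] + area2,
      and the union result            -> 3 x [N, M]    = 3 * N * M elements
    """
    return 9 * n * m * elem_bytes


pairs = pairs_in_failed_alloc(1.42)  # the 1.42 GiB from the error message
print(f"N*M ~ {pairs:.3g}")  # roughly 3.8e8 gt-anchor pairs
print(f"peak ~ {vectorized_peak_bytes(1, round(pairs)) / GIB:.1f} GiB")  # about 12.8 GiB
```

If these assumptions hold, the intermediates alone would need roughly 12.8 GiB at that line, which is consistent with a 16 GiB card running out of memory there. The estimate grows linearly in both the number of ground-truth boxes per image and the number of anchors.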

The _box_inter_union function in torchvision/ops/boxes.py appears to consume a large amount of memory when len(boxes1) and len(boxes2) are large, because it materializes several [N, M, 2] and [N, M] intermediate tensors at once. I altered the code as follows to work around the issue:

```python
from typing import Tuple

import torch
from torch import Tensor
from torchvision.ops.boxes import box_area


def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    # original
    # area1 = box_area(boxes1)
    # area2 = box_area(boxes2)

    # lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    # rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # wh = _upcast(rb - lt).clamp(min=0)  # [N,M,2]
    # inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    # union = area1[:, None] + area2 - inter

    # return inter, union

    # patched: loop over boxes1 so only [M, 2] intermediates exist at a time
    area1 = box_area(boxes1)  # [N]
    area2 = box_area(boxes2)  # [M]

    N, M = boxes1.size(0), boxes2.size(0)

    inter = torch.zeros((N, M), device=boxes1.device)
    union = torch.zeros((N, M), device=boxes1.device)

    for i in range(N):
        lt = torch.max(boxes1[i, :2], boxes2[:, :2])  # [M, 2]
        rb = torch.min(boxes1[i, 2:], boxes2[:, 2:])  # [M, 2]
        # note: unlike the original, this skips _upcast
        wh = (rb - lt).clamp(min=0)  # [M, 2]
        curr_inter = wh[:, 0] * wh[:, 1]  # [M]

        inter[i] = curr_inter
        union[i] = area1[i] + area2 - curr_inter

    return inter, union
```
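One way to keep most of the batching while still bounding peak memory is to process the rows of boxes1 in chunks rather than one at a time. The sketch below assumes floating-point (x1, y1, x2, y2) boxes; box_iou_chunked and chunk_size are hypothetical names, not a torchvision API, and it returns IoU directly instead of (inter, union).

```python
import torch
from torch import Tensor


def _area(boxes: Tensor) -> Tensor:
    # boxes are (x1, y1, x2, y2)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_iou_chunked(boxes1: Tensor, boxes2: Tensor, chunk_size: int = 256) -> Tensor:
    """IoU between boxes1 [N, 4] and boxes2 [M, 4], returned as an [N, M] tensor.

    Only chunk_size rows of boxes1 are expanded at a time, so the [C, M, 2]
    intermediates scale with chunk_size * M instead of N * M.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    area1 = _area(boxes1)  # [N]
    area2 = _area(boxes2)  # [M]
    iou = boxes1.new_empty((boxes1.size(0), boxes2.size(0)))  # same dtype/device as boxes1
    for start in range(0, boxes1.size(0), chunk_size):
        stop = start + chunk_size
        b1 = boxes1[start:stop]  # [C, 4]
        lt = torch.max(b1[:, None, :2], boxes2[:, :2])  # [C, M, 2]
        rb = torch.min(b1[:, None, 2:], boxes2[:, 2:])  # [C, M, 2]
        wh = (rb - lt).clamp(min=0)  # [C, M, 2]
        inter = wh[:, :, 0] * wh[:, :, 1]  # [C, M]
        union = area1[start:stop, None] + area2 - inter  # [C, M]
        iou[start:stop] = inter / union
    return iou
```

With chunk_size=1 this behaves like the per-row loop; with chunk_size >= N it is a single batched pass. The number of loop iterations drops from N to ceil(N / chunk_size), so the chunk size picks a point on the speed/memory trade-off.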

Below is the GPU memory usage before and after the modification.

[Figure: GPU memory usage over time, panels "Before modification" and "After modification"]

It is very odd that memory usage increases at all: before the modification, GPU memory usage keeps increasing over time. This is clearly not expected behaviour. Can anyone help me figure out what is going on?

Versions

[pip3] numpy==1.25.2
[pip3] pytorch-lightning==2.0.8
[pip3] torch==2.0.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.0.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0

pmeier commented 1 year ago

I haven't checked in detail, but from skimming your patch, it seems you have replaced batch processing with a loop, meaning you are trading performance for memory.
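To put numbers on that trade-off, one option is to time each candidate and record its extra peak CUDA memory on the same inputs. This is a minimal sketch; profile_pairwise_fn is a hypothetical helper built on torch.utils.benchmark.Timer and the torch.cuda memory counters.

```python
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor
from torch.utils import benchmark


def profile_pairwise_fn(
    fn: Callable[[Tensor, Tensor], object],
    boxes1: Tensor,
    boxes2: Tensor,
    number: int = 10,
) -> Tuple[float, Optional[int]]:
    """Return (mean seconds per call, extra peak CUDA bytes) for fn(boxes1, boxes2).

    The peak is None when the inputs are not on a CUDA device.
    """
    timer = benchmark.Timer(
        stmt="fn(boxes1, boxes2)",
        globals={"fn": fn, "boxes1": boxes1, "boxes2": boxes2},
    )
    # benchmark.Timer takes care of warmup and CUDA synchronization
    seconds = timer.timeit(number).mean

    extra_peak = None
    if boxes1.is_cuda:
        device = boxes1.device
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)
        baseline = torch.cuda.memory_allocated(device)  # inputs and anything else already live
        fn(boxes1, boxes2)
        torch.cuda.synchronize(device)
        extra_peak = torch.cuda.max_memory_allocated(device) - baseline
    return seconds, extra_peak
```

For example, calling it with torchvision.ops.box_iou and then with a looped or chunked variant on the same GPU tensors (hypothetical gt_boxes and anchors of realistic size) shows how much time is paid for each byte of peak memory saved.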

As for the graphs, I don't know how they were created. It is odd to me that in both, the memory changes quite heavily over time. For example, what is happening in the lower one, i.e. the one with your patch, at minute 10? And again at minute 20?

crazyboy9103 commented 1 year ago

> I haven't checked in detail, but from skimming your patch, it seems you have replaced batch processing with a loop, meaning you are trading performance for memory.

Yes, by replacing the batch processing with a loop I was able to reduce the peak memory and avoid the OOM.

> As for the graphs, I don't know how they were created. It is odd to me that in both, the memory changes quite heavily over time. For example, what is happening in the lower one, i.e. the one with your patch, at minute 10? And again at minute 20?

The graphs were generated automatically by wandb. I also found it odd, but I haven't found a reason for it. I'm using PyTorch Lightning to train the model and WandbLogger to log metrics, images, etc. It seems like something is causing a memory leak. I've reviewed my code for a couple of days now but haven't found anything that could cause this behavior, as it doesn't differ from the torchvision implementation.
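One way to tell a leak apart from per-batch spikes is to log PyTorch's own allocator counters alongside the wandb charts. To my understanding, wandb's system charts are sampled from the driver, so they also count memory that PyTorch's caching allocator has reserved but is not currently using. The class below is a hypothetical sketch built on torch.cuda.memory_allocated, memory_reserved, max_memory_allocated and reset_peak_memory_stats. It could be called, for instance, as log.record(batch_idx) at the end of training_step.

```python
from typing import List, Optional, Tuple

import torch

MIB = 2 ** 20


class CudaMemoryLog:
    """Collect (step, allocated MiB, reserved MiB, peak MiB since the previous record)."""

    def __init__(self, device: Optional[torch.device] = None) -> None:
        self.device = device  # None means the current CUDA device
        self.rows: List[Tuple[int, float, float, float]] = []

    def record(self, step: int) -> None:
        if not torch.cuda.is_available():
            return  # nothing to measure on a CPU-only machine
        allocated = torch.cuda.memory_allocated(self.device) / MIB
        reserved = torch.cuda.memory_reserved(self.device) / MIB
        peak = torch.cuda.max_memory_allocated(self.device) / MIB
        self.rows.append((step, allocated, reserved, peak))
        # start a new peak window so each row reports the peak of its own interval
        torch.cuda.reset_peak_memory_stats(self.device)
```

Reading the rows: allocated memory that keeps climbing between records suggests something is holding on to tensors; flat allocated memory with a jumping peak points to individual batches (one possible cause being batches with many ground-truth boxes, since the IoU matrix scales with N×M); and reserved far above allocated is the fragmentation case the OOM message refers to with max_split_size_mb.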

Aside from the gradual increase, the _box_inter_union function should be modified somehow, as it substantially increases peak memory for large numbers of boxes and can make OOM errors more frequent.

I'll try to figure out on my own why GPU memory usage increases, and I'm leaving the issue open.