facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

A simple trick for a fully deterministic ROIAlign, and thus MaskRCNN training and inference #4723

Open issue, opened by ASDen 1 year ago

ASDen commented 1 year ago

Non-determinism of MaskRCNN

There have been many discussions and inquiries in this repo about a fully deterministic MaskRCNN (e.g. #4260, #3203, #2615, #2480), as well as in other detection repositories (e.g. MMDetection here and here, and torchvision here). Unfortunately, even after seeding everything and setting PyTorch's deterministic flags, results are still not repeatable.

It boils down to the fact that some of the PyTorch / torchvision ops used do not have a deterministic GPU implementation (most notably because they use atomicAdd in the backward pass). So the only option is to train for as long as possible to reduce the variance in the results. It is worth noting that not only training but also evaluation (see #2480) of MaskRCNN (and actually of most detectron2 models) is not deterministic.
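As a concrete illustration of why atomicAdd breaks repeatability: floating-point addition is not associative, so when many GPU threads add into the same float32 location, the final value depends on the order in which the additions land, and that order can change from run to run. The following pure-Python sketch is a simplified model, not the actual CUDA kernel: it emulates float32 accumulation by rounding after every addition with `struct`, and the input values are made up for illustration.

```python
import struct


def round_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate_f32(values):
    """Add values one at a time, rounding to float32 after each step.

    This mimics a series of float32 atomicAdd calls arriving in the given order.
    """
    total = 0.0
    for v in values:
        total = round_f32(total + round_f32(v))
    return total


# The same four contributions, arriving in two different orders.
tiny = 2.0 ** -24  # half of float32's spacing at 1.0
order_a = [1.0, tiny, tiny, -1.0]
order_b = [tiny, tiny, 1.0, -1.0]

# 1.0 + tiny lies exactly halfway between two float32 values and rounds
# (ties-to-even) back to 1.0, so both tiny terms are lost.
print(accumulate_f32(order_a))  # 0.0
# Here the tiny terms combine first into 2**-23, which survives the add of 1.0.
print(accumulate_f32(order_b))  # 1.1920928955078125e-07
```

Same inputs, same arithmetic, different order, different result: this is the kind of run-to-run drift the issue describes.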

Based on the minimal example in #4260, I analyzed the ops used by MaskRCNN and found that the main source of non-determinism is the backward pass of ROIAlign (see here).

Proposed solution

I am proposing a simple trick that makes ROIAlign practically fully reproducible without touching the CUDA kernel! It adds only trivial memory and computation overhead. It can be summarized as: truncate the inputs (features and boxes) to half precision, cast them to double before calling ROIAlign, and cast the result back to the original dtype.

In code, this simply means changing this function call to:

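return roi_align(
    input.half().double(),  # truncate features to half precision, then compute in float64
    rois.half().double(),  # same truncation for the boxes
    self.output_size,
    self.spatial_scale,
    self.sampling_ratio,
    self.aligned,
).to(dtype=input.dtype)  # cast the pooled features back to the input dtype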

Test

The conversion to double adds only a trivial amount of memory and computation, but performing it after the truncation significantly increases reproducibility.
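A simplified model of why truncating before the double cast helps: every finite float16 value is an integer multiple of 2**-24 with magnitude below 2**16, so any partial sum of up to 8192 such values is a multiple of 2**-24 no larger than 2**29 in magnitude, which float64 (53-bit significand) represents exactly. With no rounding anywhere, the order of the additions cannot change the result. The sketch below assumes the accumulated values are themselves half-precision numbers; in the real ROIAlign backward they are also multiplied by bilinear interpolation weights, so this is an intuition for why the trick works well in practice rather than a proof. The values are random stand-ins, not real gradients.

```python
import random
import struct


def round_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE float16 value (like tensor.half())."""
    return struct.unpack("e", struct.pack("e", x))[0]


def accumulate_f64(values):
    """Sequential float64 sum (Python floats are IEEE doubles)."""
    total = 0.0
    for v in values:
        total += v
    return total


random.seed(0)
# Stand-in contributions, truncated to half precision first (the ".half()" step).
contributions = [round_f16(random.uniform(-1.0, 1.0)) for _ in range(4096)]

# The same contributions in a different order (like a different atomicAdd schedule).
shuffled = contributions[:]
random.shuffle(shuffled)

# Every partial sum is a multiple of 2**-24 well below 2**29 in magnitude, so
# float64 holds it exactly and all orders give bit-identical results.
print(accumulate_f64(contributions) == accumulate_f64(shuffled))  # True
print(accumulate_f64(contributions) == accumulate_f64(sorted(contributions)))  # True
```

Without the truncation this argument breaks down: arbitrary float32 values carry up to 24 significant bits over a much wider exponent range, so their float64 partial sums can still be rounded and remain order-dependent.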

This solution was tested (using the same code as in #4260) and found to be fully deterministic, in both loss values and COCO evaluation results, over tens of thousands of steps for:

Note on A100

Ampere GPUs use the TF32 format for tensor-core computations by default, which means that the above truncation is done implicitly! So on Ampere-based devices it is enough to just cast to double, i.e.:

return roi_align(
    input.double(),
    rois.double(),
    self.output_size,
    self.spatial_scale,
    self.sampling_ratio,
    self.aligned,
).to(dtype=input.dtype)

Note: TF32 is enabled for cuDNN by default in PyTorch, but if it has been disabled for some reason (i.e. torch.backends.cudnn.allow_tf32 = False), the .half() truncation above is still necessary.
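The rule in this note can be written as a small decision function. This is a hypothetical helper (not part of detectron2 or torchvision); it assumes that "Ampere" can be approximated as CUDA compute capability 8.0 or newer, and that `torch.backends.cudnn.allow_tf32` is the only flag that matters, as the note states. It is plain Python so the logic can be checked without a GPU.

```python
def roi_align_casts(compute_capability_major: int, cudnn_allow_tf32: bool) -> tuple:
    """Return the casts to apply to input and rois before roi_align (hypothetical helper).

    Assumption: compute capability >= 8 stands in for "Ampere" here. With TF32 on,
    the float64 cast alone is used; otherwise truncate to half precision first.
    """
    if compute_capability_major >= 8 and cudnn_allow_tf32:
        return ("double",)
    return ("half", "double")


def apply_casts(tensor, casts):
    """Apply casts by method name, e.g. ("half", "double") -> tensor.half().double()."""
    for name in casts:
        tensor = getattr(tensor, name)()
    return tensor


print(roi_align_casts(8, True))   # ('double',)          e.g. A100 with default settings
print(roi_align_casts(8, False))  # ('half', 'double')   TF32 disabled
print(roi_align_casts(7, True))   # ('half', 'double')   e.g. V100, no TF32
```

On a GPU machine the arguments would come from `torch.cuda.get_device_capability(input.device)[0]` and `torch.backends.cudnn.allow_tf32`, and the hard-coded casts in the modified roi_align call would become `apply_casts(input, casts)` and `apply_casts(rois, casts)`.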

Note

Would love to hear what people think about this! @ppwwyyxx @fmassa

muncok commented 1 year ago

Thank you for sharing your tip.

I have tried your solution, but the loss values across runs were not identical.

You said

This solution was tested (using the same code as in https://github.com/facebookresearch/detectron2/issues/4260) and found to be fully deterministic, in both loss values and COCO evaluation results, over tens of thousands of steps for:

Were the loss values perfectly identical in your experiments?

Could you please let me know which version of PyTorch you are using?

LucQueen commented 1 year ago

Thank you @ASDen, but when I tried your solution while training on an A100, I got an error:

  warnings.warn(
/opt/conda/lib/python3.8/site-packages/torch/functional.py:599: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:2299.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "tools/train_net.py", line 170, in <module>
    launch(
  File "~/detectron2/detectron2/engine/launch.py", line 82, in launch
    main_func(*args)
  File "tools/train_net.py", line 150, in main
    return trainer.train()
  File "~/detectron2/detectron2/engine/defaults.py", line 484, in train
    super().train(self.start_iter, self.max_iter)
  File "~/detectron2/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "~/detectron2/detectron2/engine/defaults.py", line 494, in run_step
    self._trainer.run_step()
  File "~/detectron2/detectron2/engine/train_loop.py", line 274, in run_step
    loss_dict = self.model(data)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/detectron2/detectron2/modeling/meta_arch/rcnn.py", line 167, in forward
    _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/detectron2/detectron2/modeling/roi_heads/roi_heads.py", line 739, in forward
    losses = self._forward_box(features, proposals)
  File "~/detectron2/detectron2/modeling/roi_heads/roi_heads.py", line 798, in _forward_box
    box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/detectron2/detectron2/modeling/poolers.py", line 261, in forward
    output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level))
RuntimeError: expected scalar type Double but found Float

muncok commented 1 year ago

Oh, my code became reproducible after I applied the following code snippet.

Thank you!!

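import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # set before importing torch
import torch
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # disable autotuning, which may pick different kernels per run
torch.use_deterministic_algorithms(True)  # use deterministic implementations, raise for ops that have none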
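The environment variable in that snippet has to be in place before cuBLAS is first used, and setting it before torch is imported, as the snippet does, is the safest way to make sure it is picked up. A minimal pure-Python guard for that step is sketched below; `ensure_cublas_workspace_config` is a hypothetical name, and the two accepted values are the ones listed in PyTorch's reproducibility notes.

```python
import os
import warnings

# Values listed in PyTorch's reproducibility notes for deterministic cuBLAS.
KNOWN_DETERMINISTIC_CONFIGS = (":4096:8", ":16:8")


def ensure_cublas_workspace_config(default: str = ":4096:8") -> str:
    """Set CUBLAS_WORKSPACE_CONFIG if unset; warn if it holds an unexpected value.

    Hypothetical helper, not part of detectron2 or PyTorch. Call it before importing torch.
    """
    current = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    if current is None:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = default
        return default
    if current not in KNOWN_DETERMINISTIC_CONFIGS:
        warnings.warn(
            f"CUBLAS_WORKSPACE_CONFIG={current!r} is not one of {KNOWN_DETERMINISTIC_CONFIGS}; "
            "torch.use_deterministic_algorithms(True) may raise on cuBLAS calls."
        )
    return current


ensure_cublas_workspace_config()
# import torch  # only after the variable is in place
```

Leaving an existing value untouched (and only warning) avoids silently overriding a setting the user chose on purpose.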

GeoffreyChen777 commented 1 year ago

You are a real lifesaver.

luca-serra commented 11 months ago

Works like a charm! The training is fully reproducible and the performance does not worsen. Thanks for this hack!

katherinegls commented 9 months ago

Thank you @ASDen. I tried your solution on a detectron2-based Sparse R-CNN on an RTX 3090 GPU. Although training is now fully reproducible, the loss did not decrease. My modification to ROIAlign is as follows:

return roi_align(
    input.half().double(),
    rois.half().double(),
    self.output_size,
    self.spatial_scale,
    self.sampling_ratio,
    self.aligned,
).to(dtype=input.dtype)

How can I solve the problem of the loss not decreasing?