IDEA-Research / MaskDINO

[CVPR 2023] Official implementation of the paper "Mask DINO: Towards A Unified Transformer-based Framework for Object Detection and Segmentation"
Apache License 2.0

Attempting to copy inputs of <function pairwise_iou at > to CPU due to CUDA OOM #90

Open JINWONMIN opened 11 months ago

JINWONMIN commented 11 months ago

While training on a custom dataset in COCO format, we run out of GPU memory during evaluation and detectron2 falls back to the CPU.

Does detectron2 build the test dataloader differently from the train dataloader, e.g. putting all mini-batches on the GPU at once instead of loading one mini-batch per GPU?

The images in the val2017 directory of my custom dataset total about 22 GB, while COCO's val2017 is only about 800 MB. On the same server, COCO evaluates fine, but with the custom dataset the inputs are copied to the CPU due to insufficient GPU memory, as shown below.

The batch size is 4 and I am using 4 GPUs.

This dataset ran fine with DINO, but with MaskDINO I run into the issue above. Is there a way to fix this?
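For reference, a minimal sketch of how the evaluation loader is built (it assumes the nia_valid dataset registered by the script further below). detectron2's build_detection_test_loader yields one image per GPU regardless of SOLVER.IMS_PER_BATCH, so an evaluation OOM is driven by per-image resolution rather than by the batch size:

from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader

# Minimal sketch: "nia_valid" is the custom dataset registered in the script below.
cfg = get_cfg()
val_loader = build_detection_test_loader(cfg, "nia_valid")

batch = next(iter(val_loader))
print(len(batch))               # 1 -- the test loader always yields one image per GPU
print(batch[0]["image"].shape)  # image tensor after the default test-time ResizeShortestEdge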



[08/11 10:31:29 d2.data.datasets.coco]: Loaded 9067 images in COCO format from MaskDINO/datasets/nia/annotations/instances_val2017.json
[08/11 10:31:30 d2.data.build]: Distribution of instances among all 8 categories:
|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|  sagging   | 639          |  flooding  | 746          |  pinhole   | 957          |
|   crack    | 650          |   swell    | 604          |    weld    | 2499         |
|  scratch   | 2567         |  falloff   | 2351         |            |              |
|   total    | 11013        |            |              |            |              |
[08/11 10:31:30 d2.data.dataset_mapper]: [DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[08/11 10:31:30 d2.data.common]: Serializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[08/11 10:31:30 d2.data.common]: Serializing 9067 elements to byte tensors and concatenating them all ...
[08/11 10:31:30 d2.data.common]: Serialized dataset takes 8.06 MiB
[08/11 10:31:30 d2.evaluation.evaluator]: Start inference on 2267 batches
/home/xaiplanet/miniconda3/envs/maskdino/lib/python3.8/site-packages/torch/functional.py:478: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2894.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[08/11 10:31:39 d2.utils.memory]: Attempting to copy inputs of <bound method MaskDINO.instance_inference of MaskDINO(
.
.
.
      num_classes: 8
      eos_coef: 0.1
      num_points: 12544
      oversample_ratio: 3.0
      importance_sample_ratio: 0.75
)> to CPU due to CUDA OOM
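For reference, the "Attempting to copy inputs ... to CPU due to CUDA OOM" line comes from detectron2's retry_if_cuda_oom wrapper, which retries a call and, as a last resort, moves its tensor arguments to the CPU. A minimal sketch with a toy function (not MaskDINO's actual inference code):

import torch
from detectron2.utils.memory import retry_if_cuda_oom

def toy_pairwise_op(a, b):
    # stand-in for pairwise_iou / instance_inference; any tensor op works here
    return a[:, None] + b[None, :]

a = torch.randn(8)
b = torch.randn(8)
# If the wrapped call raises a CUDA OOM, detectron2 retries it once after emptying
# the cache and then again with the tensor arguments moved to the CPU, which is
# exactly what the "Attempting to copy inputs ... to CPU" log line reports.
out = retry_if_cuda_oom(toy_pairwise_op)(a, b)
print(out.shape)  # torch.Size([8, 8])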

For reference, here is the dataset registration script I am using:

import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg
from detectron2.data.datasets.coco import register_coco_instances

NIA_CATEGORIES = [
    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "sagging"},
    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "flooding"},
    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "pinhole"},
    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "crack"},
    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "swell"},
    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "weld"},
    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "scratch"},
    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "falloff"},
]


def _get_coco_stuff_meta():
    # Id 0 is reserved for ignore_label; we change ignore_label from 0
    # to 255 in our pre-processing.
    stuff_ids = [k["id"] for k in NIA_CATEGORIES]
    assert len(stuff_ids) == 8, len(stuff_ids)

    # For semantic segmentation, this maps contiguous stuff ids
    # (in [0, 7], used in models) to ids in the dataset (used for processing results).
    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
    stuff_classes = [k["name"] for k in NIA_CATEGORIES]

    ret = {
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
    }
    return ret


def register_custom_coco_dataset(root: str) -> None:
    root = os.path.join(root, "nia")
    meta = _get_coco_stuff_meta()
    annotations_path = os.path.join(root, "annotations")
    register_coco_instances(
        "nia_train",
        meta,
        os.path.join(annotations_path, "instances_train2017.json"),
        os.path.join(root, "train2017"),
    )
    register_coco_instances(
        "nia_valid",
        meta,
        os.path.join(annotations_path, "instances_val2017.json"),
        os.path.join(root, "val2017"),
    )
    register_coco_instances(
        "nia_test",
        meta,
        os.path.join(annotations_path, "instances_test2017.json"),
        os.path.join(root, "test2017"),
    )


_root = os.getenv("DETECTRON2_DATASETS", "datasets")
register_custom_coco_dataset(_root)
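As a quick sanity check of the registration (a minimal sketch; it assumes the script above has already been imported so the nia_* datasets are registered):

from detectron2.data import DatasetCatalog, MetadataCatalog

dicts = DatasetCatalog.get("nia_valid")   # list of per-image dataset dicts
meta = MetadataCatalog.get("nia_valid")
print(len(dicts))                         # expected: 9067, as in the log above
print(meta.thing_classes)                 # class names loaded from instances_val2017.json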



kamenrideraway commented 9 months ago

I ran into the same situation. I found that the evaluation batch size is always 1, but I haven't solved the problem.

shiyue-code commented 9 months ago

You can try setting the cfg.INPUT.IMAGE_SIZE parameter to 512, but this may result in a decrease in accuracy.

JINWONMIN commented 9 months ago

I ran into the same situation. I found that the evaluation batch size is always 1, but I haven't solved the problem.

I resolved the issue by reducing the resolution of the original images.
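In case it helps anyone else, a rough sketch of the kind of offline downscaling I mean (assumptions: polygon segmentations in the COCO JSON, RLE/crowd masks are not handled, and the paths in the commented call are placeholders):

import json
import os

from PIL import Image


def downscale_coco(json_in, img_dir_in, json_out, img_dir_out, scale=0.5):
    """Resize every image by `scale` and rescale the COCO annotations to match."""
    os.makedirs(img_dir_out, exist_ok=True)
    with open(json_in) as f:
        coco = json.load(f)

    for img_info in coco["images"]:
        im = Image.open(os.path.join(img_dir_in, img_info["file_name"]))
        new_size = (round(im.width * scale), round(im.height * scale))
        im.resize(new_size, Image.BILINEAR).save(os.path.join(img_dir_out, img_info["file_name"]))
        img_info["width"], img_info["height"] = new_size

    for ann in coco["annotations"]:
        ann["bbox"] = [v * scale for v in ann["bbox"]]   # [x, y, w, h]
        ann["area"] = ann["area"] * scale * scale
        if isinstance(ann["segmentation"], list):        # polygon format only; RLE not handled here
            ann["segmentation"] = [[v * scale for v in poly] for poly in ann["segmentation"]]

    with open(json_out, "w") as f:
        json.dump(coco, f)


# placeholder paths, following the directory layout used in the registration script above
# downscale_coco("datasets/nia/annotations/instances_val2017.json", "datasets/nia/val2017",
#                "datasets/nia_small/annotations/instances_val2017.json", "datasets/nia_small/val2017")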

ZitengXue commented 6 months ago

I ran into the same situation. I found that the evaluation batch size is always 1, but I haven't solved the problem.

I resolved the issue by reducing the resolution of the original images.

Does reducing the resolution affect accuracy?

Bf-Zheng commented 2 months ago

You can try setting the cfg.INPUT.IMAGE_SIZE parameter to 512, but this may result in a decrease in accuracy.

It seems that during evaluation the image and the inference output are interpolated back to the original image size, so this setting only helps during model.train().
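To illustrate why (a toy example, not the repository's exact post-processing code): the predicted masks are upsampled to the original (height, width) recorded in the input dict, so the memory cost of this step scales with the original image size no matter what resolution the model saw:

import torch
import torch.nn.functional as F

mask_logits = torch.randn(10, 200, 267)        # [num_queries, H', W'] predicted at reduced resolution
orig_h, orig_w = 3000, 4000                    # original image size from the dataset dict
full_res = F.interpolate(mask_logits[None], size=(orig_h, orig_w),
                         mode="bilinear", align_corners=False)[0]
print(full_res.shape)                          # torch.Size([10, 3000, 4000])
# Even these 10 full-resolution masks take ~0.5 GB in float32; with a few hundred
# queries this single step can easily exceed the free GPU memory.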

Bf-Zheng commented 1 month ago

During evaluation, the dataset mapper used for the COCO instance segmentation task is coco_instance_new_baseline_dataset_mapper.py, which feeds the original image to the model. This leads to CUDA OOM when the original images are very large, e.g. 3000x4000.

The issue can be solved by using a different dataset mapper for the evaluation dataloader: define a resize transform and apply it to the image and annotations inside the mapper, as the authors do in detr_dataset_mapper.py (a sketch of how to wire this into the evaluation loader follows the snippets below):

import logging

from detectron2.data import transforms as T


def build_transform_gen(cfg, is_train):
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))

    logger = logging.getLogger(__name__)

    tfm_gens = []
    if is_train:
        tfm_gens.append(T.RandomFlip())
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens
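For orientation, a minimal skeleton of how such a mapper might hold these transforms (the class name SmallSizeTestMapper is illustrative, not something from the repository; its __call__() body is the excerpt below):

import copy

import numpy as np
import torch

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T  # also used by the __call__ excerpt below


class SmallSizeTestMapper:
    """Illustrative test-time mapper; its __call__() body is the excerpt below."""

    def __init__(self, cfg, is_train=False):
        self.is_train = is_train
        self.img_format = cfg.INPUT.FORMAT
        # ResizeShortestEdge built from cfg.INPUT.MIN_SIZE_TEST / MAX_SIZE_TEST
        self.tfm_gens = build_transform_gen(cfg, is_train)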
# in the __call__() method of your dataset mapper
        # (earlier in __call__, the image was loaded with utils.read_image and
        #  padding_mask was initialized as np.ones(image.shape[:2]))
        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~ padding_mask.astype(bool)

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                # Let's always keep mask
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            # ... then convert annos to Instances as in the original mapper
            # (utils.annotations_to_instances) and store them in dataset_dict["instances"].

With the resizing applied in the evaluation mapper, the CUDA OOM issue should be resolved.
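To plug such a mapper into evaluation, one option (a sketch, not code from this repository; SmallSizeTestMapper is the illustrative class from above) is to override build_test_loader in the trainer used by train_net.py:

from detectron2.data import build_detection_test_loader
from detectron2.engine import DefaultTrainer


class Trainer(DefaultTrainer):
    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        # Hand the resizing mapper to the evaluation loader so large images are
        # downscaled before they ever reach the model.
        return build_detection_test_loader(
            cfg, dataset_name, mapper=SmallSizeTestMapper(cfg, is_train=False)
        )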