open-mmlab / mmrotate

OpenMMLab Rotated Object Detection Toolbox and Benchmark
https://mmrotate.readthedocs.io/en/latest/
Apache License 2.0

[1.x] CUDA out of memory for RTMDet-R (tiny) with 24GB of VRAM and batch_size=1 #745

Open kikefdezl opened 1 year ago

kikefdezl commented 1 year ago

Prerequisite

Task

I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.

Branch

1.x branch https://github.com/open-mmlab/mmrotate/tree/1.x

Environment

sys.platform: linux
Python: 3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA A10G
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.6, V11.6.124
GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 1.13.1+cu117
PyTorch compiling details: PyTorch built with:

TorchVision: 0.14.1+cu117
OpenCV: 4.7.0
MMEngine: 0.3.0
MMRotate: 1.0.0rc1+5d0491c

Reproduces the problem - code sample

This is the full training config, based on rotated_rtmdet_tiny-3x-dota.py. It is the parsed version that is saved to the work_dir when training:

default_scope = 'mmrotate'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=12, max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='mmdet.DetVisualizationHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='RotLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
load_from = None
resume = False
custom_hooks = [
    dict(type='mmdet.NumClassCheckHook'),
    dict(
        type='EMAHook',
        ema_type='mmdet.ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49)
]
max_epochs = 36
base_lr = 0.00025
interval = 2
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=2)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
        end=1000),
    dict(
        type='CosineAnnealingLR',
        eta_min=1.25e-05,
        begin=18,
        end=36,
        T_max=18,
        by_epoch=True,
        convert_to_iter_based=True)
]
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.00025, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
dataset_type = 'DOTADataset'
data_root = 'data/'
classes = ('car', 'tree', 'building')
file_client_args = dict(backend='disk')
train_pipeline = [
    dict(
        type='mmdet.LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(
        type='mmdet.RandomFlip',
        prob=0.75,
        direction=['horizontal', 'vertical', 'diagonal']),
    dict(
        type='RandomRotate',
        prob=0.5,
        angle_range=180,
        rect_obj_labels=[9, 11]),
    dict(
        type='mmdet.Pad', size=(1024, 1024),
        pad_val=dict(img=(114, 114, 114))),
    dict(type='mmdet.PackDetInputs')
]
val_pipeline = [
    dict(
        type='mmdet.LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.Pad', size=(1024, 1024),
        pad_val=dict(img=(114, 114, 114))),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
test_pipeline = [
    dict(
        type='mmdet.LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(
        type='mmdet.Pad', size=(1024, 1024),
        pad_val=dict(img=(114, 114, 114))),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=False,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    pin_memory=False,
    dataset=dict(
        type='DOTADataset',
        data_root='data/',
        ann_file='dota_split_train/annfiles/',
        data_prefix=dict(img_path='dota_split_train/images/'),
        metainfo=dict(classes = ('car', 'tree', 'building')),
        img_shape=(1024, 1024),
        filter_cfg=dict(filter_empty_gt=True),
        pipeline=[
            dict(
                type='mmdet.LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(
                type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
            dict(
                type='ConvertBoxType',
                box_type_mapping=dict(gt_bboxes='rbox')),
            dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
            dict(
                type='mmdet.RandomFlip',
                prob=0.75,
                direction=['horizontal', 'vertical', 'diagonal']),
            dict(
                type='RandomRotate',
                prob=0.5,
                angle_range=180,
                rect_obj_labels=[9, 11]),
            dict(
                type='mmdet.Pad',
                size=(1024, 1024),
                pad_val=dict(img=(114, 114, 114))),
            dict(type='mmdet.PackDetInputs')
        ]))
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='DOTADataset',
        data_root='data/',
        ann_file='dota_split_val/annfiles/',
        data_prefix=dict(img_path='dota_split_val/images/'),
        metainfo=dict(classes = ('car', 'tree', 'building')),
        img_shape=(1024, 1024),
        test_mode=True,
        pipeline=[
            dict(
                type='mmdet.LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
            dict(
                type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
            dict(
                type='ConvertBoxType',
                box_type_mapping=dict(gt_bboxes='rbox')),
            dict(
                type='mmdet.Pad',
                size=(1024, 1024),
                pad_val=dict(img=(114, 114, 114))),
            dict(
                type='mmdet.PackDetInputs',
                meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                           'scale_factor'))
        ]))
test_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='DOTADataset',
        data_root='data/',
        ann_file='dota_split_val/annfiles/',
        data_prefix=dict(img_path='dota_split_val/images/'),
        metainfo=dict(classes = ('car', 'tree', 'building')),
        img_shape=(1024, 1024),
        test_mode=True,
        pipeline=[
            dict(
                type='mmdet.LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
            dict(
                type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
            dict(
                type='ConvertBoxType',
                box_type_mapping=dict(gt_bboxes='rbox')),
            dict(
                type='mmdet.Pad',
                size=(1024, 1024),
                pad_val=dict(img=(114, 114, 114))),
            dict(
                type='mmdet.PackDetInputs',
                meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                           'scale_factor'))
        ]))
val_evaluator = dict(type='DOTAMetric', metric='mAP')
test_evaluator = dict(type='DOTAMetric', metric='mAP')
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
angle_version = 'le90'
model = dict(
    type='mmdet.RTMDet',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        boxtype2tensor=False,
        batch_augments=None),
    backbone=dict(
        type='mmdet.CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=0.167,
        widen_factor=0.375,
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint=
            'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
        )),
    neck=dict(
        type='mmdet.CSPNeXtPAFPN',
        in_channels=[96, 192, 384],
        out_channels=96,
        num_csp_blocks=1,
        expand_ratio=0.5,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU')),
    bbox_head=dict(
        type='RotatedRTMDetSepBNHead',
        num_classes=3,
        in_channels=96,
        stacked_convs=2,
        feat_channels=96,
        angle_version='le90',
        anchor_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
        bbox_coder=dict(type='DistanceAnglePointCoder', angle_version='le90'),
        loss_cls=dict(
            type='mmdet.QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='RotatedIoULoss', mode='linear', loss_weight=2.0),
        with_objectness=False,
        exp_on_reg=False,
        share_conv=True,
        pred_kernel_size=1,
        use_hbbox_loss=False,
        scale_angle=False,
        loss_angle=None,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU')),
    train_cfg=dict(
        assigner=dict(
            type='mmdet.DynamicSoftLabelAssigner',
            iou_calculator=dict(type='RBboxOverlaps2D'),
            topk=13),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=2000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms_rotated', iou_threshold=0.1),
        max_per_img=2000))
launcher = 'none'
work_dir = '/home/ubuntu/trainlogs/rtm_det_testing/'

Reproduces the problem - command or script

# Running the config above
python3 tools/train.py data/rotated_rtmdet_tiny-3x-dota.py 

Reproduces the problem - error message

02/27 15:29:46 - mmengine - INFO - Epoch(train) [1][50/24063]  lr: 1.2265e-05  eta: 1 day, 17:52:43  time: 0.1740  data_time: 0.0044  memory: 7930  loss: 1.1786  loss_cls: 0.7322  loss_bbox: 0.4465
02/27 15:29:52 - mmengine - INFO - Epoch(train) [1][100/24063]  lr: 2.4777e-05  eta: 1 day, 11:50:19  time: 0.1239  data_time: 0.0030  memory: 7562  loss: 1.5051  loss_cls: 0.7498  loss_bbox: 0.7553
02/27 15:29:58 - mmengine - INFO - Epoch(train) [1][150/24063]  lr: 3.7289e-05  eta: 1 day, 10:13:25  time: 0.1288  data_time: 0.0035  memory: 6264  loss: 1.7442  loss_cls: 0.7630  loss_bbox: 0.9813
02/27 15:30:05 - mmengine - INFO - Epoch(train) [1][200/24063]  lr: 4.9802e-05  eta: 1 day, 9:36:21  time: 0.1320  data_time: 0.0027  memory: 15644  loss: 1.8564  loss_cls: 0.7321  loss_bbox: 1.1243
02/27 15:30:11 - mmengine - INFO - Epoch(train) [1][250/24063]  lr: 6.2314e-05  eta: 1 day, 9:02:43  time: 0.1281  data_time: 0.0029  memory: 11892  loss: 2.0365  loss_cls: 0.8459  loss_bbox: 1.1906
02/27 15:30:17 - mmengine - INFO - Epoch(train) [1][300/24063]  lr: 7.4827e-05  eta: 1 day, 8:17:49  time: 0.1187  data_time: 0.0029  memory: 2915  loss: 1.9696  loss_cls: 0.8354  loss_bbox: 1.1342
02/27 15:30:24 - mmengine - INFO - Epoch(train) [1][350/24063]  lr: 8.7339e-05  eta: 1 day, 8:12:50  time: 0.1319  data_time: 0.0029  memory: 8907  loss: 1.9075  loss_cls: 0.8541  loss_bbox: 1.0533
02/27 15:30:30 - mmengine - INFO - Epoch(train) [1][400/24063]  lr: 9.9851e-05  eta: 1 day, 7:57:50  time: 0.1257  data_time: 0.0030  memory: 3041  loss: 1.9364  loss_cls: 0.8330  loss_bbox: 1.1034
02/27 15:30:36 - mmengine - INFO - Epoch(train) [1][450/24063]  lr: 1.1236e-04  eta: 1 day, 7:37:56  time: 0.1206  data_time: 0.0029  memory: 6682  loss: 1.9127  loss_cls: 0.8549  loss_bbox: 1.0578
02/27 15:30:42 - mmengine - INFO - Epoch(train) [1][500/24063]  lr: 1.2488e-04  eta: 1 day, 7:29:34  time: 0.1258  data_time: 0.0029  memory: 6428  loss: 1.9565  loss_cls: 0.8740  loss_bbox: 1.0824
02/27 15:30:49 - mmengine - INFO - Epoch(train) [1][550/24063]  lr: 1.3739e-04  eta: 1 day, 7:22:21  time: 0.1255  data_time: 0.0029  memory: 2859  loss: 1.9069  loss_cls: 0.8359  loss_bbox: 1.0710
02/27 15:30:55 - mmengine - INFO - Epoch(train) [1][600/24063]  lr: 1.4990e-04  eta: 1 day, 7:09:08  time: 0.1196  data_time: 0.0028  memory: 4678  loss: 1.8457  loss_cls: 0.8506  loss_bbox: 0.9951
02/27 15:31:01 - mmengine - INFO - Epoch(train) [1][650/24063]  lr: 1.6241e-04  eta: 1 day, 7:12:57  time: 0.1331  data_time: 0.0032  memory: 7908  loss: 1.8983  loss_cls: 0.8788  loss_bbox: 1.0196
02/27 15:31:07 - mmengine - INFO - Epoch(train) [1][700/24063]  lr: 1.7493e-04  eta: 1 day, 7:00:26  time: 0.1178  data_time: 0.0028  memory: 2636  loss: 1.8368  loss_cls: 0.8276  loss_bbox: 1.0092
02/27 15:31:13 - mmengine - INFO - Epoch(train) [1][750/24063]  lr: 1.8744e-04  eta: 1 day, 6:50:09  time: 0.1184  data_time: 0.0029  memory: 2499  loss: 1.8550  loss_cls: 0.8018  loss_bbox: 1.0532
02/27 15:31:19 - mmengine - INFO - Epoch(train) [1][800/24063]  lr: 1.9995e-04  eta: 1 day, 6:42:01  time: 0.1194  data_time: 0.0028  memory: 4978  loss: 1.8916  loss_cls: 0.8698  loss_bbox: 1.0218
02/27 15:31:25 - mmengine - INFO - Epoch(train) [1][850/24063]  lr: 2.1246e-04  eta: 1 day, 6:37:57  time: 0.1230  data_time: 0.0029  memory: 4115  loss: 1.8995  loss_cls: 0.8284  loss_bbox: 1.0711
02/27 15:31:32 - mmengine - INFO - Epoch(train) [1][900/24063]  lr: 2.2498e-04  eta: 1 day, 6:38:24  time: 0.1281  data_time: 0.0030  memory: 9617  loss: 1.8854  loss_cls: 0.8190  loss_bbox: 1.0665
02/27 15:31:38 - mmengine - INFO - Epoch(train) [1][950/24063]  lr: 2.3749e-04  eta: 1 day, 6:33:14  time: 0.1208  data_time: 0.0031  memory: 2299  loss: 1.8980  loss_cls: 0.8585  loss_bbox: 1.0395
02/27 15:31:44 - mmengine - INFO - Exp name: rotated_rtmdet_tiny-3x-dota_20230227_152918
02/27 15:31:44 - mmengine - INFO - Epoch(train) [1][1000/24063]  lr: 2.5000e-04  eta: 1 day, 6:29:47  time: 0.1225  data_time: 0.0029  memory: 5355  loss: 1.8538  loss_cls: 0.8620  loss_bbox: 0.9918
02/27 15:31:50 - mmengine - INFO - Epoch(train) [1][1050/24063]  lr: 2.5000e-04  eta: 1 day, 6:26:33  time: 0.1223  data_time: 0.0027  memory: 7432  loss: 1.8448  loss_cls: 0.8731  loss_bbox: 0.9717
02/27 15:31:56 - mmengine - INFO - Epoch(train) [1][1100/24063]  lr: 2.5000e-04  eta: 1 day, 6:20:44  time: 0.1179  data_time: 0.0026  memory: 2991  loss: 1.8463  loss_cls: 0.8612  loss_bbox: 0.9851
02/27 15:32:02 - mmengine - INFO - Epoch(train) [1][1150/24063]  lr: 2.5000e-04  eta: 1 day, 6:21:24  time: 0.1275  data_time: 0.0026  memory: 6617  loss: 1.8711  loss_cls: 0.9169  loss_bbox: 0.9542
02/27 15:32:08 - mmengine - INFO - Epoch(train) [1][1200/24063]  lr: 2.5000e-04  eta: 1 day, 6:15:51  time: 0.1173  data_time: 0.0028  memory: 4851  loss: 1.8461  loss_cls: 0.9038  loss_bbox: 0.9423
02/27 15:32:14 - mmengine - INFO - Epoch(train) [1][1250/24063]  lr: 2.5000e-04  eta: 1 day, 6:14:43  time: 0.1242  data_time: 0.0030  memory: 4157  loss: 1.8560  loss_cls: 0.8867  loss_bbox: 0.9693
/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmrotate/structures/bbox/rotated_boxes.py:192: UserWarning: The `clip` function does nothing in `RotatedBoxes`.
  warnings.warn('The `clip` function does nothing in `RotatedBoxes`.')
/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "tools/train.py", line 122, in <module>
    main()
  File "tools/train.py", line 118, in main
    runner.train()
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1661, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/runner/loops.py", line 90, in run
    self.run_epoch()
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/runner/loops.py", line 106, in run_epoch
    self.run_iter(idx, data_batch)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/runner/loops.py", line 122, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 320, in _run_forward
    results = self(**data, mode=mode)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 92, in forward
    return self.loss(inputs, data_samples)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/detectors/single_stage.py", line 78, in loss
    losses = self.bbox_head.loss(x, batch_data_samples)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/dense_heads/base_dense_head.py", line 123, in loss
    losses = self.loss_by_feat(*loss_inputs)
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmrotate/models/dense_heads/rotated_rtmdet_head.py", line 299, in loss_by_feat
    cls_reg_targets = self.get_targets(
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/dense_heads/rtmdet_head.py", line 356, in get_targets
    all_assign_metrics, sampling_results_list) = multi_apply(
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/utils/misc.py", line 219, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmrotate/models/dense_heads/rotated_rtmdet_head.py", line 395, in _get_targets_single
    assign_result = self.assigner.assign(pred_instances, gt_instances,
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py", line 165, in assign
    soft_cls_cost = F.binary_cross_entropy_with_logits(
  File "/home/ubuntu/mmrotate_tests/venv/lib/python3.8/site-packages/torch/nn/functional.py", line 3162, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.27 GiB (GPU 0; 22.20 GiB total capacity; 17.68 GiB already allocated; 1.25 GiB free; 19.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Additional information

I believe this is the smallest available RTMDet-R model. I'm using an NVIDIA A10G with 24GB of VRAM and am running the dataloaders with batch size of 1, so I expected to be able to run this model.

I'm using a custom dataset, which has been pre-cropped to 1024x1024 images.

Is it normal for this model to use so much memory?

zytx121 commented 1 year ago

Hi @kikefdezl, how many targets are there in a single image of your dataset? This error occurred during label assignment, and it was caused by an excessive number of GT boxes. Since RBboxOverlaps2D does not currently support CPU calculation, I recommend trying a smaller image size, e.g. 768x768 or 512x512.
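
A rough way to see why the GT count and image size both matter is sketched below. It assumes, and this is not stated in the thread, that the assigner builds a float32 classification cost tensor of shape (num_priors, num_gt, num_classes); the traceback failing inside `F.binary_cross_entropy_with_logits` in the assigner is at least consistent with that. The helper names `num_priors` and `cost_tensor_gib` are hypothetical, and the result is an upper bound, since the assigner may only consider a subset of the priors.

```python
def num_priors(img_size, strides=(8, 16, 32)):
    """Number of point priors for a square image, one per feature-map cell.

    The strides match `anchor_generator.strides` in the posted config.
    """
    return sum((img_size // s) ** 2 for s in strides)


def cost_tensor_gib(img_size, num_gt, num_classes, bytes_per_elem=4):
    """Upper-bound size in GiB of one (num_priors, num_gt, num_classes) tensor.

    Assumption: float32 elements and every prior taking part in the assignment.
    """
    elems = num_priors(img_size) * num_gt * num_classes
    return elems * bytes_per_elem / 2**30


if __name__ == "__main__":
    # Compare the image sizes mentioned above for a hypothetical 1000 GT boxes.
    for size in (1024, 768, 512):
        gib = cost_tensor_gib(size, num_gt=1000, num_classes=3)
        print(f"{size}x{size}: {num_priors(size)} priors, ~{gib:.2f} GiB per tensor")
```

Under these assumptions the tensor grows linearly with the number of GT boxes and with the number of priors, so halving the image side cuts it to roughly a quarter. Several intermediate tensors of this shape may be alive at once, so actual usage can be a multiple of the figure printed here.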

kikefdezl commented 1 year ago

Hi @zytx121, thanks for the answer. That certainly explains it, as some of my images have a very large number of boxes.
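
To find out which images carry the most boxes, a minimal sketch like the following could be run over the annotation folder. It assumes DOTA-style text annotations, one box per line with eight coordinates followed by a class name and a difficulty flag, and it skips shorter lines such as the optional `imagesource:` and `gsd:` headers. The function names `count_boxes` and `busiest_images` are hypothetical helpers, not part of MMRotate.

```python
from pathlib import Path


def count_boxes(ann_path):
    """Count annotation lines that look like a DOTA box (>= 9 fields)."""
    count = 0
    for line in Path(ann_path).read_text().splitlines():
        # Header lines such as 'gsd:0.1' have a single field and are skipped.
        if len(line.split()) >= 9:
            count += 1
    return count


def busiest_images(ann_dir, top=10):
    """Return the `top` annotation files with the most boxes, largest first."""
    counts = {p.name: count_boxes(p) for p in Path(ann_dir).glob("*.txt")}
    return sorted(counts.items(), key=lambda kv: kv[1], reverse=True)[:top]


if __name__ == "__main__":
    # Path taken from the posted config (data_root + ann_file).
    for name, n in busiest_images("data/dota_split_train/annfiles/"):
        print(f"{n:6d}  {name}")
```

Images at the top of this list are the likeliest to trigger the out-of-memory error during label assignment.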

This surprised me, since I've been using ReDet for a long time now and its memory usage never depended on target count. Does ReDet use a different method to assign predictions to GT?

Will this potentially be changed in the future for RTMDet?

zytx121 commented 1 year ago

@kikefdezl, DynamicSoftLabelAssigner uses more VRAM than MaxIoUAssigner. You can lower the topk parameter to reduce VRAM usage, though this may cost some accuracy.

Yes, we will consider your feedback for the next version.