Sense-X / Co-DETR

[ICCV 2023] DETRs with Collaborative Hybrid Assignments Training
MIT License
950 stars 100 forks source link

Co-DETR/mmdet/datasets/pipelines/transforms.py", line 2825, in __call__ assert 'mix_results' in results #114

Closed hubhub086 closed 5 months ago

hubhub086 commented 5 months ago

Hi!
When I use the command `python tools/train.py configs/co_detr_vit/co_dino_5scale_lsj_vit_large_lvis.py` to train on a customized dataset (1024×1024), I run into this problem:

...
2024-03-03 17:33:41,229 - mmdet - INFO - Set random seed to 1977542892, deterministic: False
======== shape of rope freq torch.Size([256, 64]) ========
======== shape of rope freq torch.Size([9216, 64]) ========
...
2024-03-03 17:33:53,514 - mmdet - INFO - workflow: [('train', 1)], max: 16 epochs
2024-03-03 17:33:53,514 - mmdet - INFO - Checkpoints will be saved to ~/projects/Co-DETR/work_dirs/co_dino_5scale_lsj_vit_large_lvis by HardDiskBackend.
...
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "~/miniconda3/envs/mmdec225/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "~/miniconda3/envs/mmdec225/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "~/miniconda3/envs/mmdec225/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "~/projects/Co-DETR/mmdet/datasets/custom.py", line 220, in __getitem__
    data = self.prepare_train_img(idx)
  File "~/projects/Co-DETR/mmdet/datasets/custom.py", line 243, in prepare_train_img
    return self.pipeline(results)
  File "~/projects/Co-DETR/mmdet/datasets/pipelines/compose.py", line 41, in __call__
    data = t(data)
  File "~/projects/Co-DETR/mmdet/datasets/pipelines/transforms.py", line 2825, in __call__
    assert 'mix_results' in results
AssertionError

Can you give me some advice? Thanks!

Here are the modified configs:

_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/default_runtime.py'
]
checkpoint_config = dict(interval=1)
resume_from = None
load_from = None
pretrained = None
# Windowed-attention block indexes for the 24-layer ViT: every block except
# the global-attention ones, which sit at every third index (2, 5, ..., 20).
window_block_indexes = [i for i in range(23) if i % 3 != 2]
residual_block_indexes = []

# Number of decoder layers shared by the collaborative heads.
num_dec_layer = 6
# Loss-weight multiplier for the auxiliary (collaborative) heads.
lambda_2 = 2.0

# NOTE(review): this model dict was truncated by the issue author -- the
# trailing `...` is not valid Python here, so the snippet is not runnable
# as posted; the elided fields (neck, heads, etc.) come from the original
# co_dino_5scale_lsj_vit_large config.
model = dict(
    type='CoDETR',
    with_attn_mask=False,
    backbone=dict(
        type='ViT',
        # NOTE(review): img_size=1536 while the data pipeline pads to
        # (1024, 1024) -- presumably fine with use_lsj, but worth confirming.
        img_size=1536,
        pretrain_img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4*2/3,  # EVA-02-style reduced MLP ratio
        drop_path_rate=0.3,
        window_size=16,
        window_block_indexes=window_block_indexes,
        residual_block_indexes=residual_block_indexes,
        qkv_bias=True,
        use_act_checkpoint=True,  # activation checkpointing to save memory
        use_lsj=True,
        init_cfg=dict(type='Pretrained', checkpoint='pretrain/eva02_L_pt_m38m_p14to16.pt')),
    ...)

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
# image_size = (1536, 1536)
image_size = (1024, 1024)

# Stage 1 -- per-image loading with large-scale jitter (LSJ): resize in a
# (0.1, 2.0) ratio range, random-crop and pad to a fixed square canvas.
# This runs inside the wrapped dataset, BEFORE the mixing stage below.
load_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=image_size,
        ratio_range=(0.1, 2.0),
        multiscale_mode='range',
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_type='absolute_range',
        crop_size=image_size,
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
]

# Stage 2 -- mixing and formatting. CopyPaste asserts 'mix_results' in its
# input dict (transforms.py:2825); that key is injected only by the
# MultiImageMixDataset wrapper, so this list must be attached to that
# wrapper's `pipeline` -- never directly to a plain dataset.
train_pipeline = [
    dict(type='CopyPaste', max_num_pasted=100),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=image_size,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

data_root = '/mnt/e/data/'
classes = ('target',)

data = dict(
    samples_per_gpu=1,
    workers_per_gpu=1,
    # FIX for "assert 'mix_results' in results": the training set must be
    # wrapped in MultiImageMixDataset, which supplies the paste candidates
    # ('mix_results') that CopyPaste consumes. The raw dataset runs
    # load_pipeline; the wrapper runs train_pipeline.
    train=dict(
        _delete_=True,  # replace (don't merge with) the _base_ plain dataset
        type='MultiImageMixDataset',
        dataset=dict(
            type='CocoDataset',
            ann_file=data_root + 'annotation_train.json',
            img_prefix=data_root + 'images/train/',
            classes=classes,
            filter_empty_gt=False,
            pipeline=load_pipeline),
        pipeline=train_pipeline),
    val=dict(
        ann_file=data_root + 'annotation_val.json',
        img_prefix=data_root + 'images/val/',
        classes=classes,
        pipeline=test_pipeline),
    test=dict(
        ann_file=data_root + 'annotation_test.json',
        img_prefix=data_root + 'images/test/',
        classes=classes,
        pipeline=test_pipeline))

# Box-AP evaluation only (no mask metric).
evaluation = dict(metric='bbox')
dist_params = dict(backend='nccl')

# Schedule: 500-iteration linear warmup, then step LR decay at epochs 9
# and 15 over a 16-epoch run.
lr_config = dict(
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    policy='step',
    step=[9, 15])
runner = dict(max_epochs=16, type='EpochBasedRunner')

# AdamW with the ViT backbone trained at 10% of the base learning rate.
# (Layer-wise LR decay is mentioned upstream but not implemented here.)
optimizer = {
    'type': 'AdamW',
    'lr': 5e-5,
    'weight_decay': 0.05,
    # Only the backbone multiplier is set; the DeformDETR custom keys
    # (sampling_offsets / reference_points) are left to their defaults.
    'paramwise_cfg': {'custom_keys': {'backbone': {'lr_mult': 0.1}}},
}
optimizer_config = {'grad_clip': None}