open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Training stuck after index is created when training from checkpoint #11534

Open jdtrellesm opened 6 months ago

jdtrellesm commented 6 months ago

When I try to start training from my latest checkpoint, the process stays stuck indefinitely after the index is created. This is the last message in the terminal output:

Advance dataloader 26667 steps to skip data that has already been trained

When I run the same config file but just change resume=False, it runs as expected from scratch. This is my configuration file (the exact lines I toggle are summarized right after it):

auto_scale_lr = dict(base_batch_size=16, enable=False)
backend_args = None
classes = ('Car', 'bike', 'Motorcycle', 'boat', 'airplane', 'street', 'person', 'tree')
batch_augments = [ dict( img_pad_value=0, mask_pad_value=0, pad_mask=True, pad_seg=False, size=( 1024, 1024, ), type='BatchFixedSizePad'), ]
data_preprocessor = dict( batch_augments=[ dict( img_pad_value=0, mask_pad_value=0, pad_mask=True, pad_seg=False, size=( 1024, 1024, ), type='BatchFixedSizePad'), ], bgr_to_rgb=True, mask_pad_value=0, mean=[ 123.675, 116.28, 103.53, ], pad_mask=True, pad_seg=False, pad_size_divisor=32, seg_pad_value=255, std=[ 58.395, 57.12, 57.375, ], type='DetDataPreprocessor')
data_root = './data'
dataset_type = 'CocoDataset'
default_hooks = dict( checkpoint=dict( by_epoch=False, interval=25000, max_keep_ckpts=3, save_last=True, type='CheckpointHook'), logger=dict(interval=50, type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), sampler_seed=dict(type='DistSamplerSeedHook'), timer=dict(type='IterTimerHook'), visualization=dict(type='DetVisualizationHook'))
default_scope = 'mmdet'
dynamic_intervals = [ ( 365001, 368750, ), ]
embed_multi = dict(decay_mult=0.0, lr_mult=1.0)
env_cfg = dict( cudnn_benchmark=False, dist_cfg=dict(backend='nccl'), mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
image_size = ( 1024, 1024, )
interval = 5000
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False, type='LogProcessor', window_size=50)
max_iters = 666667
model = dict( backbone=dict( depth=50, frozen_stages=-1, init_cfg=dict(checkpoint='torchvision://resnet50', type='Pretrained'), norm_cfg=dict(requires_grad=False, type='BN'), norm_eval=True, num_stages=4, out_indices=( 0, 1, 2, 3, ), style='pytorch', type='ResNet'), data_preprocessor=dict( batch_augments=[ dict( img_pad_value=0, mask_pad_value=0, pad_mask=True, pad_seg=False, size=( 1024, 1024, ), type='BatchFixedSizePad'), ], bgr_to_rgb=True, mask_pad_value=0, mean=[ 123.675, 116.28, 103.53, ], pad_mask=True, pad_seg=False, pad_size_divisor=32, seg_pad_value=255, std=[ 58.395, 57.12, 57.375, ], type='DetDataPreprocessor'), init_cfg=None, panoptic_fusion_head=dict( init_cfg=None, loss_panoptic=None, num_stuff_classes=8, num_things_classes=0, type='MaskFormerFusionHead'), panoptic_head=dict( enforce_decoder_input_project=False, feat_channels=256, in_channels=[ 256, 512, 1024, 2048, ], loss_cls=dict( class_weight=[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ], loss_weight=2.0, reduction='mean', type='CrossEntropyLoss', use_sigmoid=False), loss_dice=dict( activate=True, eps=1.0, loss_weight=5.0, naive_dice=True, reduction='mean', type='DiceLoss', use_sigmoid=True), loss_mask=dict( loss_weight=5.0, reduction='mean', type='CrossEntropyLoss', use_sigmoid=True), num_queries=100, num_stuff_classes=8, num_things_classes=0, num_transformer_feat_level=3, out_channels=256, pixel_decoder=dict( act_cfg=dict(type='ReLU'), encoder=dict( layer_cfg=dict( ffn_cfg=dict( act_cfg=dict(inplace=True, type='ReLU'), embed_dims=256, feedforward_channels=1024, ffn_drop=0.0, num_fcs=2), self_attn_cfg=dict( batch_first=True, dropout=0.0, embed_dims=256, num_heads=8, num_levels=3, num_points=4)), num_layers=6), norm_cfg=dict(num_groups=32, type='GN'), num_outs=3, positional_encoding=dict(normalize=True, num_feats=128), type='MSDeformAttnPixelDecoder'), positional_encoding=dict(normalize=True, num_feats=128), strides=[ 4, 8, 16, 32, ], transformer_decoder=dict( init_cfg=None, layer_cfg=dict( cross_attn_cfg=dict( batch_first=True, dropout=0.0, embed_dims=256, num_heads=8), ffn_cfg=dict( act_cfg=dict(inplace=True, type='ReLU'), embed_dims=256, feedforward_channels=2048, ffn_drop=0.0, num_fcs=2), self_attn_cfg=dict( batch_first=True, dropout=0.0, embed_dims=256, num_heads=8)), num_layers=9, return_intermediate=True), type='Mask2FormerHead'), test_cfg=dict( filter_low_score=True, instance_on=True, iou_thr=0.8, max_per_image=100, panoptic_on=False, semantic_on=False), train_cfg=dict( assigner=dict( match_costs=[ dict(type='ClassificationCost', weight=2.0), dict( type='CrossEntropyLossCost', use_sigmoid=True, weight=5.0), dict(eps=1.0, pred_act=True, type='DiceCost', weight=5.0), ], type='HungarianAssigner'), importance_sample_ratio=0.75, num_points=12544, oversample_ratio=3.0, sampler=dict(type='MaskPseudoSampler')), type='Mask2Former')
num_classes = 8
num_stuff_classes = 8
num_things_classes = 0
optim_wrapper = dict( clip_grad=dict(max_norm=0.01, norm_type=2), optimizer=dict( betas=( 0.9, 0.999, ), eps=1e-08, lr=0.0001, type='AdamW', weight_decay=0.05), paramwise_cfg=dict( custom_keys=dict( backbone=dict(decay_mult=1.0, lr_mult=0.1), level_embed=dict(decay_mult=0.0, lr_mult=1.0), query_embed=dict(decay_mult=0.0, lr_mult=1.0), query_feat=dict(decay_mult=0.0, lr_mult=1.0)), norm_decay_mult=0.0), type='OptimWrapper')
param_scheduler = dict( begin=0, by_epoch=False, end=368750, gamma=0.1, milestones=[ 327778, 355092, ], type='MultiStepLR')
load_from = 'work_dirs/custom_config/iter_26667.pth'
resume=False
test_cfg = dict(type='TestLoop')
test_dataloader = dict( batch_size=1, dataset=dict( metainfo=dict(classes=classes), ann_file='coco_val_v4.json.', backend_args=None, data_prefix=dict(img='./images/'), data_root='./data', pipeline=[ dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), dict(keep_ratio=True, scale=( 1333, 800, ), type='Resize'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( meta_keys=( 'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', ), type='PackDetInputs'), ], test_mode=True, type='CocoDataset'), drop_last=False, num_workers=2, persistent_workers=True, sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict( ann_file='/coco_test_v4.json', backend_args=None, format_only=False, metric=[ 'bbox', 'segm', ], type='CocoMetric')
test_pipeline = [ dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), dict(keep_ratio=True, scale=( 1333, 800, ), type='Resize'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( meta_keys=( 'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', ), type='PackDetInputs'), ]
train_cfg = dict( dynamic_intervals=[ ( 365001, 368750, ), ], max_iters=666667, type='IterBasedTrainLoop', val_interval=5000)
train_dataloader = dict( batch_sampler=dict(type='AspectRatioBatchSampler'), batch_size=2, dataset=dict( metainfo=dict(classes=classes), ann_file='./coco_train_v4.json', backend_args=None, data_prefix=dict( img='./images/'), data_root='./data', filter_cfg=dict(filter_empty_gt=True, min_size=32), pipeline=[ dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(prob=0.5, type='RandomFlip'), dict( keep_ratio=True, ratio_range=( 0.1, 2.0, ), resize_type='Resize', scale=( 1024, 1024, ), type='RandomResize'), dict( allow_negative_crop=True, crop_size=( 1024, 1024, ), crop_type='absolute', recompute_bbox=True, type='RandomCrop'), dict( by_mask=True, min_gt_bbox_wh=( 1e-05, 1e-05, ), type='FilterAnnotations'), dict(type='PackDetInputs'), ], type='CocoDataset'), num_workers=2, persistent_workers=True, sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [ dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(prob=0.5, type='RandomFlip'), dict( keep_ratio=True, ratio_range=( 0.1, 2.0, ), resize_type='Resize', scale=( 1024, 1024, ), type='RandomResize'), dict( allow_negative_crop=True, crop_size=( 1024, 1024, ), crop_type='absolute', recompute_bbox=True, type='RandomCrop'), dict( by_mask=True, min_gt_bbox_wh=( 1e-05, 1e-05, ), type='FilterAnnotations'), dict(type='PackDetInputs'), ]
val_cfg = dict(type='ValLoop')
val_dataloader = dict( batch_size=1, dataset=dict( metainfo=dict(classes=classes), ann_file='coco_val_v4.json', backend_args=None, data_prefix=dict(img='./images/'), data_root='./data', pipeline=[ dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), dict(keep_ratio=True, scale=( 1333, 800, ), type='Resize'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( meta_keys=( 'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', ), type='PackDetInputs'), ], test_mode=True, type='CocoDataset'), drop_last=False, num_workers=2, persistent_workers=True, sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict( ann_file='./data/coco_val_v4.json', backend_args=None, format_only=False, metric=[ 'bbox', 'segm', ], type='CocoMetric')
vis_backends = [ dict(type='LocalVisBackend'), ]
visualizer = dict( name='visualizer', type='DetLocalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), ])
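For clarity, the only setting I change between the two runs is the resume flag; everything else stays exactly as in the dump above (a sketch of the two variants, assuming the flag is toggled in the config):

# run that hangs after the "Advance dataloader ..." message is printed
load_from = 'work_dirs/custom_config/iter_26667.pth'
resume = True

# run that works (starts from scratch, as described above)
resume = False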

I am training on 3 A100 GPUs (but I run into the same issue if I use only one), with CUDA 12.0. I could not find any similar issues. Any help would be greatly appreciated.

xika1234 commented 5 months ago

I met the same problem, did you solve it?

chenxi52 commented 4 months ago

I met the same problem.

ShenZheng2000 commented 3 months ago

Same problem.

SurfonL commented 3 months ago

This happens when you use IterBasedTrainLoop and try to resume training: on resume, the loop advances the dataloader one batch at a time to skip the data it has already seen, which can take an extremely long time with a heavy pipeline, so the process only looks stuck. The simplest workaround is to edit the mmengine source, e.g. ~/anaconda3/envs/sam/lib/python3.11/site-packages/mmengine/runner/loops.py (the path depends on your environment; see the snippet after the code block to locate it), and comment out the following lines:

# if self._iter > 0:
#     print_log(
#         f'Advance dataloader {self._iter} steps to skip data '
#         'that has already been trained',
#         logger='current',
#         level=logging.WARNING)
#     for _ in range(self._iter):
#         next(self.dataloader_iterator)
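If your mmengine is installed somewhere else, you can print the exact path of the loops.py to edit, for example:

import mmengine.runner.loops as loops
print(loops.__file__)  # location of the file that defines IterBasedTrainLoop in your environment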
ajeetkverma commented 1 month ago

this happens when you use IterBasedTrainLoop and you try to resume training. The simplest solution is to edit the source code of mmengine: ~/anaconda3/envs/sam/lib/python3.11/site-packages/mmengine/runner/loops.py and comment out the following lines:

# if self._iter > 0:
#     print_log(
#         f'Advance dataloader {self._iter} steps to skip data '
#         'that has already been trained',
#         logger='current',
#         level=logging.WARNING)
#     for _ in range(self._iter):
#         next(self.dataloader_iterator)

It worked... Thank you.