MengyuWang826 / SegRefiner

SegRefiner: Towards Model-Agnostic Segmentation Refinement with Discrete Diffusion Process
Apache License 2.0

I got a strange result from training with the HR dataset #7

Open yusuke-ai opened 5 months ago

yusuke-ai commented 5 months ago

Hi,

Thank you for the awesome work! I trained the model on the HR dataset with almost the same configuration as your code, using the command below. The only difference is the learning rate, which I set to 4e-5, and I saved models at iteration 10000 and at around iteration 60000. (I didn't use the learning rate from your code because the loss sometimes spikes during training; I will comment more on this in #4.)

python tools/train.py configs/segrefiner/segrefiner_hr.py --resume-from segrefiner_hr_latest.pth

However, I got different refinement results. The original segrefiner_hr_latest.pth model produces a smooth segmentation along the line, but the retrained model produces a jagged segmentation along the line, as shown below. I expected the retrained model to produce the same output.

Could you help me find the root cause? Thank you!

Segmentation result with the segrefiner_hr_latest.pth

Screenshot from 2024-01-19 11-10-58

Segmentation result with the retrained model

Screenshot from 2024-01-19 11-13-11

yusuke-ai commented 5 months ago

@MengyuWang826 Here is my training configuration, just in case.

checkpoint_config = dict(
    interval=5000, by_epoch=False, save_last=True, max_keep_ckpts=20)
log_config = dict(
    interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = 'segrefiner_hr_latest.pth'
workflow = [('train', 5000)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
object_size = 256
task = 'instance'
model = dict(
    type='SegRefiner',
    task='instance',
    step=6,
    denoise_model=dict(
        type='DenoiseUNet',
        in_channels=4,
        out_channels=1,
        model_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        attention_strides=(16, 32),
        learn_time_embd=True,
        channel_mult=(1, 1, 2, 2, 4, 4),
        dropout=0.0),
    diffusion_cfg=dict(
        betas=dict(type='linear', start=0.8, stop=0, num_timesteps=6),
        diff_iter=False),
    test_cfg=dict())
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadAnnotations',
        with_bbox=False,
        with_label=False,
        with_mask=True),
    dict(type='LoadPatchData', object_size=256, patch_size=256),
    dict(type='Resize', img_scale=(256, 256), keep_ratio=False),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=[
            'object_img', 'object_gt_masks', 'object_coarse_masks',
            'patch_img', 'patch_gt_masks', 'patch_coarse_masks'
        ])
]
dataset_type = 'HRCollectionDataset'
img_root = '/share/project/datasets/MSCOCO/coco2017/'
ann_root = '/share/project/datasets/LVIS/'
train_dataloader = dict(samples_per_gpu=4, workers_per_gpu=1)
data = dict(
    train=dict(
        type='HRCollectionDataset',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='LoadAnnotations',
                with_bbox=False,
                with_label=False,
                with_mask=True),
            dict(type='LoadPatchData', object_size=256, patch_size=256),
            dict(type='Resize', img_scale=(256, 256), keep_ratio=False),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='DefaultFormatBundle'),
            dict(
                type='Collect',
                keys=[
                    'object_img', 'object_gt_masks', 'object_coarse_masks',
                    'patch_img', 'patch_gt_masks', 'patch_coarse_masks'
                ])
        ],
        data_root='data/',
        collection_datasets=['thin', 'dis'],
        collection_json='data/collection_hr.json'),
    train_dataloader=dict(samples_per_gpu=4, workers_per_gpu=1),
    val=dict(),
    test=dict())
optimizer = dict(
    type='AdamW', lr=0.00004, weight_decay=0, eps=1e-08, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)
max_iters = 120000
runner = dict(type='IterBasedRunner', max_iters=120000)
lr_config = dict(
    policy='step',
    gamma=0.5,
    by_epoch=False,
    step=[80000, 100000],
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,
    warmup_iters=10)
interval = 5000
data_root = 'data/'
work_dir = './work_dirs/segrefiner_hr'
auto_resume = False
gpu_ids = [0]
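For reference, the sketch below computes the learning rate this config produces at a given iteration. It assumes MMCV 1.x semantics for `policy='step'` (the rate is multiplied by `gamma` once per milestone reached) and for linear warmup (the rate ramps from `warmup_ratio` × lr up to lr); the function name `lr_at` is hypothetical. Under these assumptions, `warmup_ratio=1.0` makes the 10-iteration warmup a no-op, and both the 10000- and 60000-iteration checkpoints were trained at a constant 4e-5.

```python
def lr_at(it, base_lr=4e-5, gamma=0.5, steps=(80000, 100000),
          warmup_iters=10, warmup_ratio=1.0):
    """Learning rate at iteration `it` (0-based) for a step schedule with linear warmup.

    Assumes MMCV 1.x semantics: multiply by `gamma` once for every milestone in
    `steps` already reached; during the first `warmup_iters` iterations, ramp
    linearly from `warmup_ratio * lr` up to `lr`.
    """
    # Number of milestones already passed.
    exp = sum(1 for s in steps if it >= s)
    regular_lr = base_lr * gamma ** exp
    if it < warmup_iters:
        # k shrinks from (1 - warmup_ratio) to 0 over the warmup period;
        # with warmup_ratio=1.0 it is always 0, so warmup changes nothing.
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


for it in (0, 10000, 60000, 90000, 110000):
    print(it, lr_at(it))
```

With the defaults, this gives 4e-5 for every iteration before 80000 (including the warmup), 2e-5 from 80000, and 1e-5 from 100000.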
MengyuWang826 commented 5 months ago

@yusuke-ai In this sample, the jagged segmentation appears to occur because some regions receive only global refinement and no local refinement. Taking the parameters that control local patch collection in the segrefiner_big.py config as an example:

model = dict(
    type='SegRefinerSemantic',
    task=task,
    test_cfg=dict(
        model_size=object_size,
        fine_prob_thr=0.9,
        iou_thr=0.3,
        batch_max=32))

You can try reducing fine_prob_thr so that more local patches are collected.

Alternatively, you can try a different method for collecting local patches, such as sampling them along the mask edges. How this step is implemented only determines which local patches get refined; it does not affect how the model itself works.
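One way to collect local patches along the mask edges, as suggested above, is sketched below. It is a minimal illustration under stated assumptions, not the repository's code: edge pixels are foreground pixels with at least one 4-connected background neighbour, each edge pixel gets a square patch, and overlapping patches are suppressed greedily. The names `mask_edges`, `box_iou`, `edge_patches` and the `patch_size` argument are hypothetical; `iou_thr` and `batch_max` borrow the config's parameter names, but using them for IoU-based suppression and a patch cap here is an assumption.

```python
import numpy as np


def mask_edges(mask):
    """Boolean map of boundary pixels: foreground pixels with at least one
    4-connected background neighbour (pixels outside the image count as background)."""
    m = np.pad(mask.astype(bool), 1, constant_values=False)
    core = m[1:-1, 1:-1]
    interior = core & m[:-2, 1:-1] & m[2:, 1:-1] & m[1:-1, :-2] & m[1:-1, 2:]
    return core & ~interior


def box_iou(box, boxes):
    """IoU between one box and an (N, 4) array of boxes, all as (x1, y1, x2, y2)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def edge_patches(mask, patch_size=64, iou_thr=0.3, batch_max=32):
    """Square patches centred on mask-edge pixels, de-duplicated greedily by IoU.

    Assumes patch_size <= mask height and width. Returns an (N, 4) array of
    (x1, y1, x2, y2) boxes with N <= batch_max.
    """
    h, w = mask.shape
    ys, xs = np.nonzero(mask_edges(mask))
    half = patch_size // 2
    # Shift each patch so it stays fully inside the image.
    x1 = np.clip(xs - half, 0, w - patch_size)
    y1 = np.clip(ys - half, 0, h - patch_size)
    boxes = np.stack([x1, y1, x1 + patch_size, y1 + patch_size], axis=1)
    kept = []
    for box in boxes:
        if len(kept) == batch_max:
            break
        # Keep a patch only if it does not overlap any kept patch too much.
        if not kept or box_iou(box, np.array(kept)).max() <= iou_thr:
            kept.append(box)
    return np.array(kept).reshape(-1, 4)
```

For example, `edge_patches(coarse_mask, patch_size=64)` on a binary coarse mask returns boxes that all touch the mask boundary. Compared with a confidence-based rule, this guarantees that every part of the contour is a candidate, at the cost of possibly refining edges that were already correct.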

yusuke-ai commented 5 months ago

@MengyuWang826 Thank you for the answer! I tried changing fine_prob_thr from 0.1 to 1.0, but I still got similarly jagged results. Can you think of any other reasons for the jagged output?

MengyuWang826 commented 5 months ago

@yusuke-ai

[image]

Since not every region shows jagged segmentation, I suspect that the areas within the red box are the normal output, and that the jagged regions are the ones that did not receive local refinement. You can start by visualizing which local patches have actually undergone local refinement.
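A minimal sketch of that visualization, assuming you can log the coordinates of the refined local patches as (x1, y1, x2, y2) boxes in image pixels; the function name `draw_patch_boxes` and that box layout are assumptions, not part of the repository.

```python
import numpy as np


def draw_patch_boxes(image, boxes, color=(255, 0, 0), thickness=2):
    """Return a copy of an RGB uint8 image with each (x1, y1, x2, y2) box outlined.

    Right/bottom coordinates are exclusive; boxes are clipped to the image.
    """
    out = image.copy()
    h, w = out.shape[:2]
    for x1, y1, x2, y2 in boxes:
        x1, y1 = max(int(x1), 0), max(int(y1), 0)
        x2, y2 = min(int(x2), w), min(int(y2), h)
        if x2 <= x1 or y2 <= y1:
            continue  # box lies entirely outside the image
        t = thickness
        out[y1:min(y1 + t, y2), x1:x2] = color  # top edge
        out[max(y2 - t, y1):y2, x1:x2] = color  # bottom edge
        out[y1:y2, x1:min(x1 + t, x2)] = color  # left edge
        out[y1:y2, max(x2 - t, x1):x2] = color  # right edge
    return out
```

Overlaying the result on the refined mask (for example, saving it with `PIL.Image.fromarray(vis).save(...)`) makes it easy to check whether the jagged stretches fall outside every outlined patch.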

yusuke-ai commented 5 months ago

@MengyuWang826 Thank you! I experimented with the local refinement code and concluded that the model I trained is simply less capable than the one you provided. I will keep experimenting with training, but it would be really helpful if you could retrain the model with the code in this repository.

wzx0720 commented 5 months ago

@yusuke-ai Hey! I am also trying to train HR-SegRefiner starting from the pretrained model segrefiner_hr_latest.pth, but I ran into this error: self._epoch = checkpoint['meta']['epoch'] raised KeyError: 'meta'. Have you run into this problem, and if so, how did you solve it? Thanks a lot!

yusuke-ai commented 5 months ago

@wzx0720 Sorry for the late reply. I just edited the code around that line so that training starts from epoch 0.

wzx0720 commented 5 months ago

Sorry to bother you again, but I don't understand how to edit the lines. Do you mean editing segrefiner_hr_latest.pth or editing the .py file? Could you show me how you edited the lines? Thanks a lot again!
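The KeyError comes from the training runner's `resume()`, which (in MMCV 1.x) reads `checkpoint['meta']['epoch']` and `checkpoint['meta']['iter']` and therefore expects a full training checkpoint; judging by the error, the released segrefiner_hr_latest.pth has no `meta` entry. It is not known exactly which lines yusuke-ai edited. One alternative that avoids touching library code is to add the missing entry to a copy of the checkpoint. A minimal sketch, assuming an MMCV 1.x runner; the helper names `add_resume_meta` and `patch_checkpoint_file` are hypothetical:

```python
def add_resume_meta(checkpoint, epoch=0, iteration=0):
    """Return a copy of `checkpoint` with the 'meta' entry that resume() reads.

    Existing keys are preserved; a missing 'meta' dict is created, and missing
    'epoch'/'iter' fields are filled with the given starting values.
    """
    patched = dict(checkpoint)
    meta = dict(patched.get('meta') or {})
    meta.setdefault('epoch', epoch)
    meta.setdefault('iter', iteration)
    patched['meta'] = meta
    return patched


def patch_checkpoint_file(src, dst):
    """Load a checkpoint with PyTorch, add resume metadata, and save it to `dst`."""
    import torch  # imported here so add_resume_meta() works without PyTorch

    ckpt = torch.load(src, map_location='cpu')
    if 'state_dict' not in ckpt:
        # Assume a bare state dict and wrap it in the usual MMCV checkpoint layout.
        ckpt = {'state_dict': ckpt}
    torch.save(add_resume_meta(ckpt), dst)
```

Calling `patch_checkpoint_file('segrefiner_hr_latest.pth', 'segrefiner_hr_resumable.pth')` and resuming from the new file would start at iteration 0. Because the file carries no optimizer state, AdamW starts fresh, so the effect is essentially the same as fine-tuning with `load_from`.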

wang21jun commented 2 months ago

Same problem here. Have you solved it?

wzx0720 commented 2 months ago

@wang21jun I haven't tried it, but maybe you could set load_from in configs/_base_/default_runtime.py to fine-tune the model.
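In MMDetection 2.x-style configs, which this repository appears to build on, `load_from` loads only the model weights and starts training at iteration 0, while `resume_from` also restores `meta` (epoch/iter) and optimizer state, which is what triggers the `KeyError: 'meta'` reported above. A sketch of the suggested override, assuming it goes in configs/_base_/default_runtime.py or in the HR config itself; the checkpoint path is only an example:

```python
# Fine-tune from the released weights instead of resuming a training run.
load_from = 'segrefiner_hr_latest.pth'  # weights only; iteration counter starts at 0
resume_from = None                      # resuming needs a checkpoint with 'meta'
```

In MMDetection 2.x, `resume_from` takes precedence over `load_from` when both are set, and the `--resume-from` flag of tools/train.py overwrites `resume_from`. So drop that flag from the training command when fine-tuning this way.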

wang21jun commented 2 months ago

Thanks, I will try it.

wang21jun commented 2 months ago

I loaded from segrefiner_lr_latest.pth and trained on DIS5K, but the results are still jagged. [image]

wzx0720 commented 2 months ago

Yes, I had the same problem in my own work, so I have stopped using this method. I think you could open an issue and ask the author.