SysCV / sam-hq

Segment Anything in High Quality [NeurIPS 2023]
https://arxiv.org/abs/2306.01567
Apache License 2.0

Training problem: during sam-hq training the validation IoU is very high (0.98), but in eval mode the validation IoU is only 0.48 #124

Open YUANMU227 opened 8 months ago

YUANMU227 commented 8 months ago

problem

During sam-hq training, the IoU reported on the validation set is very high (about 0.98). However, when the saved checkpoint is evaluated in eval mode (the --eval flag), the validation IoU is only about 0.48. The sam-hq model produced by continuing training with this code performs very poorly and does not segment well. What is the cause?

training instructions

python -m torch.distributed.launch --nproc_per_node=1 train.py --checkpoint ./pretrained_checkpoint/sam_vit_h_4b8939.pth --model-type vit_h --output work_dirs/hq_sam_h

training log

...
epoch: 10 learning rate: 0.0001
[ 0/500] eta: 0:22:06 training_loss: 0.1545 (0.1545) loss_mask: 0.0859 (0.0859) loss_dice: 0.0686 (0.0686) time: 2.6534 data: 0.5714 max mem: 18844
[499/500] eta: 0:00:02 training_loss: 0.1074 (0.1478) loss_mask: 0.0694 (0.0956) loss_dice: 0.0350 (0.0522) time: 2.0849 data: 0.0027 max mem: 18844
Total time: 0:17:22 (2.0846 s / it)
Finished epoch: 10
Averaged stats: training_loss: 0.1074 (0.1478) loss_mask: 0.0694 (0.0956) loss_dice: 0.0350 (0.0522)
Validating...
valid_dataloader len: 400
[ 0/400] eta: 0:04:56 val_iou_0: 0.9876 (0.9876) val_boundary_iou_0: 0.9019 (0.9019) time: 0.7404 data: 0.2156 max mem: 18844
[399/400] eta: 0:00:00 val_iou_0: 0.9923 (0.9837) val_boundary_iou_0: 0.8787 (0.8512) time: 0.5078 data: 0.0020 max mem: 18844
Total time: 0:03:25 (0.5142 s / it)
============================
Averaged stats: val_iou_0: 0.9923 (0.9837) val_boundary_iou_0: 0.8787 (0.8512)
come here save at work_dirs/hq_sam_h/epoch_10.pth
epoch: 11 learning rate: 0.0001
[ 0/500] eta: 0:21:31 training_loss: 0.0657 (0.0657) loss_mask: 0.0441 (0.0441) loss_dice: 0.0216 (0.0216) time: 2.5827 data: 0.4712 max mem: 18844
[499/500] eta: 0:00:02 training_loss: 0.0918 (0.1338) loss_mask: 0.0584 (0.0877) loss_dice: 0.0332 (0.0462) time: 2.0814 data: 0.0027 max mem: 18844
Total time: 0:17:20 (2.0811 s / it)
Finished epoch: 11
Averaged stats: training_loss: 0.0918 (0.1338) loss_mask: 0.0584 (0.0877) loss_dice: 0.0332 (0.0462)
Validating...
valid_dataloader len: 400
[ 0/400] eta: 0:04:52 val_iou_0: 0.9867 (0.9867) val_boundary_iou_0: 0.9143 (0.9143) time: 0.7306 data: 0.2118 max mem: 18844
[399/400] eta: 0:00:00 val_iou_0: 0.9924 (0.9841) val_boundary_iou_0: 0.9008 (0.8570) time: 0.5099 data: 0.0020 max mem: 18844
Total time: 0:03:25 (0.5140 s / it)
============================
Averaged stats: val_iou_0: 0.9924 (0.9841) val_boundary_iou_0: 0.9008 (0.8570)
come here save at work_dirs/hq_sam_h/epoch_11.pth
Training Reaches The Maximum Epoch Number
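Each metric in these logs is printed as two numbers, for example val_iou_0: 0.9923 (0.9837). This reading assumes the logging utility follows the common DETR-style SmoothedValue convention, which has not been checked against this repository. Under that convention, the first number is the median over a short window of recent iterations, and the number in parentheses is the running average over all iterations so far. The "Averaged stats" line would then report the average over the whole validation set. Below is a minimal pure-Python sketch of that convention. The class is hypothetical and written for illustration only:

```python
from collections import deque
from statistics import median


class SmoothedValue:
    """Hypothetical metric tracker printing "median (global average)".

    Assumption: mirrors the DETR-style MetricLogger convention; this is not
    the repository's own code.
    """

    def __init__(self, window_size: int = 20) -> None:
        self.window = deque(maxlen=window_size)  # keeps only recent values
        self.total = 0.0  # weighted sum of every value seen
        self.count = 0    # number of values seen

    def update(self, value: float, n: int = 1) -> None:
        self.window.append(value)
        self.total += value * n
        self.count += n

    @property
    def median(self) -> float:
        return median(self.window)

    @property
    def global_avg(self) -> float:
        return self.total / self.count

    def __str__(self) -> str:
        # Same layout as "0.9923 (0.9837)" in the log lines above.
        return f"{self.median:.4f} ({self.global_avg:.4f})"


if __name__ == "__main__":
    meter = SmoothedValue(window_size=3)
    for iou in [0.5, 0.9, 0.95, 0.99]:
        meter.update(iou)
    print(meter)  # median of the last 3 values, mean of all 4
```

If this reading is correct, the parenthesized values are the ones to compare. They are 0.9841 after epoch 11 during training versus 0.4921 in the --eval run, so the gap lies in the averaged metric itself and not in how a single iteration happened to print.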

evaluation instructions

python -m torch.distributed.launch --nproc_per_node=1 train.py --checkpoint ./pretrained_checkpoint/sam_vit_h_4b8939.pth --model-type vit_h --output work_dirs/hq_sam_h_eval --eval --restore-model work_dirs/hq_sam_h/epoch_11.pth

evaluation log

restore model from: work_dirs/hq_sam_h/epoch_11.pth
Validating...
valid_dataloader len: 400
[ 0/400] eta: 0:13:26 val_iou_0: 0.5542 (0.5542) val_boundary_iou_0: 0.2936 (0.2936) time: 2.0170 data: 0.1741 max mem: 8293
[399/400] eta: 0:00:00 val_iou_0: 0.4820 (0.4921) val_boundary_iou_0: 0.2288 (0.2294) time: 0.5036 data: 0.0021 max mem: 8741
Total time: 0:03:22 (0.5073 s / it)
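To check the reported val_iou_0 independently of the training script, it helps to recall that it is a mask intersection-over-union, averaged over the validation samples. The sketch below makes two assumptions. First, both masks are already binarized and given as flat 0/1 sequences. Second, two empty masks count as a perfect match. The repository's own helper may binarize predictions and treat empty masks differently, and the function names here are hypothetical:

```python
from typing import Sequence


def mask_iou(pred: Sequence[int], gt: Sequence[int]) -> float:
    """IoU of two binary masks given as flat 0/1 sequences of equal length.

    Assumption: inputs are already binarized; two empty masks score 1.0.
    """
    if len(pred) != len(gt):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return 1.0 if union == 0 else intersection / union


def mean_iou(preds: Sequence[Sequence[int]], gts: Sequence[Sequence[int]]) -> float:
    """Dataset-level score: the plain average of per-sample IoUs."""
    scores = [mask_iou(p, g) for p, g in zip(preds, gts)]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    gt = [0, 1, 1, 1, 0, 0]
    exact = [0, 1, 1, 1, 0, 0]    # identical to gt
    partial = [0, 1, 1, 0, 1, 0]  # 2 pixels shared out of 4 covered
    print(mask_iou(exact, gt), mask_iou(partial, gt), mean_iou([exact, partial], [gt, gt]))
```

Applying such a function to the predictions from both the training-time validation and the --eval run would show whether the two runs produce different masks or only score them differently.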