SysCV / sam-hq

Segment Anything in High Quality [NeurIPS 2023]
https://arxiv.org/abs/2306.01567
Apache License 2.0

The fine-tuned HQ-SAM model shows a significant improvement in accuracy. However, the val_iou_0 accuracy of HQ-SAM during training is very low. #73

geoexploring closed this issue 1 year ago

geoexploring commented 1 year ago

Thank you for your excellent work. I have gained a lot of inspiration.

After fine-tuning HQ-SAM on my downstream task data, its accuracy improved significantly.

However, the val_iou_0 reported during training stays very low (each training image contains only one object), even when the number of training epochs is increased to 120. What could be the reason for this?

Below is the detailed training log:

args: Namespace(output='/home/user/Desktop/sam_hq/output', model_type='vit_h', checkpoint='./pretrained_checkpoint/sam_vit_h_4b8939.pth', device='cuda', seed=42, learning_rate=0.001, start_epoch=0, lr_drop_epoch=10, max_epoch_num=12, input_size=[650, 1250], batch_size_train=4, batch_size_valid=1, model_save_fre=1, world_size=2, dist_url='env://', rank=0, local_rank=0, find_unused_params=False, eval=False, visualize=False, restore_model=None, gpu=0, distributed=True, dist_backend='nccl')

--- create training dataloader ---
------------------------------ train --------------------------------
--->>> train  dataset  0 / 1   M4D <<<---
-im- M4D /home/user/Desktop/sam_hq/sam_hq_m4d/train/images :  1795
-gt- M4D /home/user/Desktop/sam_hq/sam_hq_m4d/train/labels :  1795
224  train dataloaders created
--- create valid dataloader ---
------------------------------ valid --------------------------------
--->>> valid  dataset  0 / 1   M4D <<<---
-im- M4D /home/user/Desktop/sam_hq/sam_hq_m4d/val/images :  203
-gt- M4D /home/user/Desktop/sam_hq/sam_hq_m4d/val/labels :  203
1  valid dataloaders created
--- define optimizer ---
epoch:    0   learning rate:   0.001
/home/user/Desktop/sam_hq/sam_hq_Env/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/user/Desktop/sam_hq/sam_hq_Env/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  [  0/224]  eta: 0:20:48  training_loss: 1.6844 (1.6844)  loss_dice: 0.9837 (0.9837)  loss_mask: 0.7007 (0.7007)  time: 5.5757  data: 0.8063  max mem: 18409
  [223/224]  eta: 0:00:02  training_loss: 0.5101 (0.5171)  loss_dice: 0.3813 (0.3909)  loss_mask: 0.1239 (0.1262)  time: 2.9547  data: 0.0031  max mem: 18811
 Total time: 0:10:42 (2.8689 s / it)
Finished epoch:       0
Averaged stats: training_loss: 0.5101 (0.5171)  loss_dice: 0.3813 (0.3909)  loss_mask: 0.1239 (0.1262)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:42  val_boundary_iou_0: 0.0437 (0.0437)  val_iou_0: 0.0437 (0.0437)  time: 1.0055  data: 0.2730  max mem: 18811
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0057 (0.0293)  val_iou_0: 0.0057 (0.0365)  time: 0.7579  data: 0.0025  max mem: 18811
 Total time: 0:01:17 (0.7625 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0057 (0.0293)  val_iou_0: 0.0057 (0.0365)
come here save at /home/user/Desktop/sam_hq/output/epoch_0.pth
epoch:    1   learning rate:   0.001
  [  0/224]  eta: 0:13:13  training_loss: 0.4415 (0.4415)  loss_dice: 0.3727 (0.3727)  loss_mask: 0.0688 (0.0688)  time: 3.5415  data: 0.6857  max mem: 18811
  [223/224]  eta: 0:00:02  training_loss: 0.4284 (0.4301)  loss_dice: 0.3158 (0.3319)  loss_mask: 0.0978 (0.0982)  time: 2.9993  data: 0.0035  max mem: 18812
 Total time: 0:11:05 (2.9731 s / it)
Finished epoch:       1
Averaged stats: training_loss: 0.4284 (0.4301)  loss_dice: 0.3158 (0.3319)  loss_mask: 0.0978 (0.0982)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:44  val_boundary_iou_0: 0.0368 (0.0368)  val_iou_0: 0.0368 (0.0368)  time: 1.0216  data: 0.2634  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0059 (0.0292)  val_iou_0: 0.0059 (0.0355)  time: 0.7731  data: 0.0026  max mem: 18812
 Total time: 0:01:19 (0.7752 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0059 (0.0292)  val_iou_0: 0.0059 (0.0355)
come here save at /home/user/Desktop/sam_hq/output/epoch_1.pth
epoch:    2   learning rate:   0.001
  [  0/224]  eta: 0:13:50  training_loss: 0.3029 (0.3029)  loss_dice: 0.2903 (0.2903)  loss_mask: 0.0126 (0.0126)  time: 3.7083  data: 0.6804  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3450 (0.4276)  loss_dice: 0.2690 (0.3291)  loss_mask: 0.0837 (0.0985)  time: 2.9533  data: 0.0034  max mem: 18812
 Total time: 0:11:03 (2.9628 s / it)
Finished epoch:       2
Averaged stats: training_loss: 0.3450 (0.4276)  loss_dice: 0.2690 (0.3291)  loss_mask: 0.0837 (0.0985)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:41  val_boundary_iou_0: 0.0373 (0.0373)  val_iou_0: 0.0373 (0.0373)  time: 0.9920  data: 0.2632  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0083 (0.0336)  val_iou_0: 0.0083 (0.0433)  time: 0.7544  data: 0.0028  max mem: 18812
 Total time: 0:01:17 (0.7592 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0083 (0.0336)  val_iou_0: 0.0083 (0.0433)
come here save at /home/user/Desktop/sam_hq/output/epoch_2.pth
epoch:    3   learning rate:   0.001
  [  0/224]  eta: 0:13:12  training_loss: 0.3277 (0.3277)  loss_dice: 0.2739 (0.2739)  loss_mask: 0.0538 (0.0538)  time: 3.5375  data: 0.7209  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3417 (0.4075)  loss_dice: 0.2744 (0.3187)  loss_mask: 0.0470 (0.0888)  time: 2.9781  data: 0.0028  max mem: 18812
 Total time: 0:11:04 (2.9660 s / it)
Finished epoch:       3
Averaged stats: training_loss: 0.3417 (0.4075)  loss_dice: 0.2744 (0.3187)  loss_mask: 0.0470 (0.0888)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:47  val_boundary_iou_0: 0.0400 (0.0400)  val_iou_0: 0.0400 (0.0400)  time: 1.0492  data: 0.2645  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0036 (0.0308)  val_iou_0: 0.0036 (0.0381)  time: 0.7657  data: 0.0026  max mem: 18812
 Total time: 0:01:18 (0.7665 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0036 (0.0308)  val_iou_0: 0.0036 (0.0381)
come here save at /home/user/Desktop/sam_hq/output/epoch_3.pth
epoch:    4   learning rate:   0.001
  [  0/224]  eta: 0:14:13  training_loss: 0.4991 (0.4991)  loss_dice: 0.4048 (0.4048)  loss_mask: 0.0943 (0.0943)  time: 3.8119  data: 0.6596  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3842 (0.4097)  loss_dice: 0.3091 (0.3169)  loss_mask: 0.0534 (0.0927)  time: 2.9631  data: 0.0032  max mem: 18812
 Total time: 0:11:05 (2.9699 s / it)
Finished epoch:       4
Averaged stats: training_loss: 0.3842 (0.4097)  loss_dice: 0.3091 (0.3169)  loss_mask: 0.0534 (0.0927)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:48  val_boundary_iou_0: 0.0371 (0.0371)  val_iou_0: 0.0371 (0.0371)  time: 1.0607  data: 0.3160  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0041 (0.0297)  val_iou_0: 0.0041 (0.0364)  time: 0.7597  data: 0.0022  max mem: 18812
 Total time: 0:01:18 (0.7653 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0041 (0.0297)  val_iou_0: 0.0041 (0.0364)
come here save at /home/user/Desktop/sam_hq/output/epoch_4.pth
epoch:    5   learning rate:   0.001
  [  0/224]  eta: 0:13:33  training_loss: 0.3959 (0.3959)  loss_dice: 0.3118 (0.3118)  loss_mask: 0.0841 (0.0841)  time: 3.6311  data: 0.8202  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3818 (0.4074)  loss_dice: 0.2736 (0.3169)  loss_mask: 0.0753 (0.0906)  time: 2.9842  data: 0.0032  max mem: 18812
 Total time: 0:11:08 (2.9862 s / it)
Finished epoch:       5
Averaged stats: training_loss: 0.3818 (0.4074)  loss_dice: 0.2736 (0.3169)  loss_mask: 0.0753 (0.0906)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:46  val_boundary_iou_0: 0.0399 (0.0399)  val_iou_0: 0.0399 (0.0399)  time: 1.0399  data: 0.2330  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0044 (0.0295)  val_iou_0: 0.0044 (0.0366)  time: 0.7589  data: 0.0025  max mem: 18812
 Total time: 0:01:17 (0.7629 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0044 (0.0295)  val_iou_0: 0.0044 (0.0366)
come here save at /home/user/Desktop/sam_hq/output/epoch_5.pth
epoch:    6   learning rate:   0.001
  [  0/224]  eta: 0:13:51  training_loss: 0.4078 (0.4078)  loss_dice: 0.3496 (0.3496)  loss_mask: 0.0582 (0.0582)  time: 3.7106  data: 0.5695  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3825 (0.4001)  loss_dice: 0.3177 (0.3112)  loss_mask: 0.0755 (0.0889)  time: 2.9607  data: 0.0029  max mem: 18812
 Total time: 0:11:04 (2.9659 s / it)
Finished epoch:       6
Averaged stats: training_loss: 0.3825 (0.4001)  loss_dice: 0.3177 (0.3112)  loss_mask: 0.0755 (0.0889)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:42  val_boundary_iou_0: 0.0385 (0.0385)  val_iou_0: 0.0385 (0.0385)  time: 1.0051  data: 0.2670  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0039 (0.0285)  val_iou_0: 0.0039 (0.0357)  time: 0.7552  data: 0.0026  max mem: 18812
 Total time: 0:01:17 (0.7637 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0039 (0.0285)  val_iou_0: 0.0039 (0.0357)
come here save at /home/user/Desktop/sam_hq/output/epoch_6.pth
epoch:    7   learning rate:   0.001
  [  0/224]  eta: 0:13:11  training_loss: 0.4397 (0.4397)  loss_dice: 0.3814 (0.3814)  loss_mask: 0.0583 (0.0583)  time: 3.5336  data: 0.5829  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3527 (0.3893)  loss_dice: 0.2628 (0.3036)  loss_mask: 0.0801 (0.0857)  time: 2.9578  data: 0.0032  max mem: 18812
 Total time: 0:11:06 (2.9772 s / it)
Finished epoch:       7
Averaged stats: training_loss: 0.3527 (0.3893)  loss_dice: 0.2628 (0.3036)  loss_mask: 0.0801 (0.0857)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:42  val_boundary_iou_0: 0.0392 (0.0392)  val_iou_0: 0.0392 (0.0392)  time: 1.0091  data: 0.2252  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0042 (0.0313)  val_iou_0: 0.0042 (0.0392)  time: 0.7551  data: 0.0025  max mem: 18812
 Total time: 0:01:17 (0.7601 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0042 (0.0313)  val_iou_0: 0.0042 (0.0392)
come here save at /home/user/Desktop/sam_hq/output/epoch_7.pth
epoch:    8   learning rate:   0.001
  [  0/224]  eta: 0:13:50  training_loss: 0.5195 (0.5195)  loss_dice: 0.3575 (0.3575)  loss_mask: 0.1620 (0.1620)  time: 3.7089  data: 0.7330  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3803 (0.4112)  loss_dice: 0.2887 (0.3164)  loss_mask: 0.0877 (0.0948)  time: 2.9565  data: 0.0029  max mem: 18812
 Total time: 0:11:04 (2.9656 s / it)
Finished epoch:       8
Averaged stats: training_loss: 0.3803 (0.4112)  loss_dice: 0.2887 (0.3164)  loss_mask: 0.0877 (0.0948)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:45  val_boundary_iou_0: 0.0403 (0.0403)  val_iou_0: 0.0403 (0.0403)  time: 1.0304  data: 0.2850  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0039 (0.0298)  val_iou_0: 0.0039 (0.0364)  time: 0.7571  data: 0.0025  max mem: 18812
 Total time: 0:01:17 (0.7609 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0039 (0.0298)  val_iou_0: 0.0039 (0.0364)
come here save at /home/user/Desktop/sam_hq/output/epoch_8.pth
epoch:    9   learning rate:   0.001
  [  0/224]  eta: 0:13:27  training_loss: 0.3432 (0.3432)  loss_dice: 0.2848 (0.2848)  loss_mask: 0.0584 (0.0584)  time: 3.6029  data: 0.5989  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3879 (0.3944)  loss_dice: 0.3550 (0.3080)  loss_mask: 0.0764 (0.0864)  time: 2.9516  data: 0.0034  max mem: 18812
 Total time: 0:11:02 (2.9597 s / it)
Finished epoch:       9
Averaged stats: training_loss: 0.3879 (0.3944)  loss_dice: 0.3550 (0.3080)  loss_mask: 0.0764 (0.0864)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:43  val_boundary_iou_0: 0.0454 (0.0454)  val_iou_0: 0.0454 (0.0454)  time: 1.0126  data: 0.2492  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0038 (0.0305)  val_iou_0: 0.0038 (0.0373)  time: 0.7546  data: 0.0025  max mem: 18812
 Total time: 0:01:17 (0.7596 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0038 (0.0305)  val_iou_0: 0.0038 (0.0373)
come here save at /home/user/Desktop/sam_hq/output/epoch_9.pth
epoch:    10   learning rate:   0.0001
  [  0/224]  eta: 0:13:33  training_loss: 0.2689 (0.2689)  loss_dice: 0.2421 (0.2421)  loss_mask: 0.0268 (0.0268)  time: 3.6322  data: 0.7611  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3277 (0.3679)  loss_dice: 0.2293 (0.2891)  loss_mask: 0.0631 (0.0789)  time: 2.9478  data: 0.0033  max mem: 18812
 Total time: 0:11:01 (2.9525 s / it)
Finished epoch:       10
Averaged stats: training_loss: 0.3277 (0.3679)  loss_dice: 0.2293 (0.2891)  loss_mask: 0.0631 (0.0789)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:42  val_boundary_iou_0: 0.0490 (0.0490)  val_iou_0: 0.0490 (0.0490)  time: 1.0006  data: 0.2471  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0041 (0.0306)  val_iou_0: 0.0041 (0.0380)  time: 0.7537  data: 0.0024  max mem: 18812
 Total time: 0:01:17 (0.7603 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0041 (0.0306)  val_iou_0: 0.0041 (0.0380)
come here save at /home/user/Desktop/sam_hq/output/epoch_10.pth
epoch:    11   learning rate:   0.0001
  [  0/224]  eta: 0:13:43  training_loss: 0.3107 (0.3107)  loss_dice: 0.2759 (0.2759)  loss_mask: 0.0348 (0.0348)  time: 3.6765  data: 0.7278  max mem: 18812
  [223/224]  eta: 0:00:02  training_loss: 0.3138 (0.3660)  loss_dice: 0.2551 (0.2876)  loss_mask: 0.0517 (0.0784)  time: 2.9423  data: 0.0031  max mem: 18812
 Total time: 0:11:00 (2.9502 s / it)
Finished epoch:       11
Averaged stats: training_loss: 0.3138 (0.3660)  loss_dice: 0.2551 (0.2876)  loss_mask: 0.0517 (0.0784)
Validating...
valid_dataloader len: 102
  [  0/102]  eta: 0:01:41  val_boundary_iou_0: 0.0517 (0.0517)  val_iou_0: 0.0517 (0.0517)  time: 0.9966  data: 0.2456  max mem: 18812
  [101/102]  eta: 0:00:00  val_boundary_iou_0: 0.0040 (0.0308)  val_iou_0: 0.0040 (0.0381)  time: 0.7641  data: 0.0025  max mem: 18812
 Total time: 0:01:18 (0.7650 s / it)
============================
Averaged stats: val_boundary_iou_0: 0.0040 (0.0308)  val_iou_0: 0.0040 (0.0381)
come here save at /home/user/Desktop/sam_hq/output/epoch_11.pth
Training Reaches The Maximum Epoch Number

Thanks.

ymq2017 commented 1 year ago

Hi, if the visualization result is correct but the IoU metric looks wrong, the format of the GT or of the prediction at test time may be incorrect. Check the inputs to the metric computation. For example, in the line where we compute IoU, masks_hq has a range of (-inf, +inf) and labels_ori has a range of [0, 255]. Check the ranges of these variables, or visualize some intermediate variables, to determine whether the format is the problem.
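The range check suggested above can be sketched roughly as follows. This is a minimal NumPy illustration, not the repository's own metric code: the helper names `describe_range` and `binary_iou` and the thresholds (0 for logits, 128 for a [0, 255] label mask) are assumptions chosen to match the ranges mentioned in the comment. It shows how a label mask stored as {0, 1} instead of {0, 255} can drive the IoU to zero even when the prediction is right.

```python
import numpy as np


def describe_range(name, array):
    """Summarize dtype and value range so format mismatches are easy to spot."""
    return (f"{name}: dtype={array.dtype}, "
            f"min={float(array.min()):.3f}, max={float(array.max()):.3f}")


def binary_iou(pred_logits, gt_mask, logit_threshold=0.0, gt_threshold=128):
    """IoU between thresholded logits and a thresholded label mask.

    Assumed conventions (hypothetical, for illustration only):
    logits are unbounded and cut at 0; labels lie in [0, 255] and are cut at 128.
    """
    pred = pred_logits > logit_threshold
    gt = gt_mask > gt_threshold
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as a perfect match
    return float(np.logical_and(pred, gt).sum() / union)


# Toy 2x2 example: the left column is foreground.
logits = np.array([[3.2, -1.5], [0.7, -4.0]], dtype=np.float32)
labels_255 = np.array([[255, 0], [255, 0]], dtype=np.uint8)  # expected format
labels_01 = labels_255 // 255                                # wrong format: {0, 1}

print(describe_range("logits", logits))
print(describe_range("labels_255", labels_255))
print(describe_range("labels_01", labels_01))

iou_ok = binary_iou(logits, labels_255)   # labels in the expected range
iou_bad = binary_iou(logits, labels_01)   # every label falls below the threshold
print(iou_ok, iou_bad)
```

Printing the ranges right before the metric call is usually enough to tell whether the prediction or the ground truth is in an unexpected format.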

geoexploring commented 1 year ago

@ymq2017, thanks.

I found that the key is a line in the code that requires the input image dimensions to be [1024, 1024]. If I set "input_size" to any other value, the IoU is incorrect.
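One way to catch this early is a small guard on the configured size before training starts. The sketch below is hypothetical: `check_input_size` and `EXPECTED_INPUT_SIZE` are not part of the repository, and the [1024, 1024] value is taken from the observation in this thread rather than from the code.

```python
# Size reported in this thread as the only one that gave a correct IoU (assumption).
EXPECTED_INPUT_SIZE = (1024, 1024)


def check_input_size(input_size, expected=EXPECTED_INPUT_SIZE):
    """Raise early if the configured input size differs from the expected one."""
    size = tuple(int(v) for v in input_size)
    if size != expected:
        raise ValueError(
            f"input_size={list(size)} differs from {list(expected)}; "
            "validation IoU may be computed incorrectly."
        )
    return size


# The setting from the training log above would be rejected.
try:
    check_input_size([650, 1250])
    rejected = False
except ValueError as err:
    rejected = True
    print(err)

accepted = check_input_size([1024, 1024])
print(accepted)
```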