LiheYoung / UniMatch

[CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation
https://arxiv.org/abs/2208.09910
MIT License

Ask questions #19

Closed oiorrr closed 1 year ago

oiorrr commented 1 year ago

1. The first question is:

example:

```python
loss_u_s2 = criterion_u(pred_u_s2, mask_u_w_cutmixed2)
loss_u_s2 = loss_u_s2 * ((conf_u_w_cutmixed2 >= cfg['conf_thresh']) & (ignore_mask_cutmixed2 != 255))
loss_u_s2 = torch.sum(loss_u_s2) / torch.sum(ignore_mask_cutmixed2 != 255).item()
```

Take the above loss as an example: after the unsupervised loss is computed per pixel, it is masked according to two conditions (the second line). Then, on the third line, there is a division operation. Can this be understood as a normalization of the loss weight? And why is the denominator the size of the ignore_mask != 255 region only, rather than the intersection of both conditions?

2. The second question is: since this paper targets semantic segmentation, you adopt two conditions when generating pseudo labels for unlabeled data: argmax() and a confidence threshold greater than 0.95. My current research is a binary classification task rather than a multi-class one, so 0.5 implicitly becomes the first condition. If I want to borrow your idea, do I still need the 0.95 threshold?

3. The third question, which may be related to the second: if I don't use 0.95 as the threshold, can the weight of the unsupervised loss be fixed at 0.5 or 0.25? Do you think there are other good approaches?

LiheYoung commented 1 year ago

Thank you for raising these questions.

  1. Among the two filtering conditions, conf_u_w_cutmixed2 >= cfg['conf_thresh'] selects high-confidence pseudo labels, while ignore_mask_cutmixed2 != 255 excludes regions padded during image pre-processing from the loss computation. In the third line we divide only by the second condition, torch.sum(ignore_mask_cutmixed2 != 255).item(), because we want the ratio ((conf_u_w_cutmixed2 >= cfg['conf_thresh']) & (ignore_mask_cutmixed2 != 255)).sum() / (ignore_mask_cutmixed2 != 255).sum() to serve as an adaptive weight for our unsupervised loss. Concretely, at earlier training stages this adaptive weight is small because low-confidence pixels are abundant, so the model mainly learns from labeled images; as training proceeds, the weight gradually increases and the model learns more from unlabeled images (see the first sketch below). You may also refer to Equation (2) in FixMatch.

  2. I think either 0.5 or 0.95 is just a hyperparameter, regardless of the number of classes. In your binary-classification case, I guess you use a sigmoid output. You can then still keep a 0.95 confidence threshold by considering only outputs > 0.95 (confidently class 1) and outputs < 0.05 (confidently class 0); the 0.5 you mention is probably the classification threshold of the sigmoid, not our confidence threshold (see the second sketch below).

  3. I think the unsupervised loss weight can be adjusted according to your observations. It may depend on the proportion and difficulty of the unlabeled images (see the last sketch below).
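
A minimal sketch of the adaptive-weight view in point 1, assuming per-pixel losses and masks of shape [B, H, W] (the tensor names, shapes, and values here are illustrative, not the repository's exact code):

```python
import torch

loss = torch.rand(2, 4, 4)                    # per-pixel unsupervised loss, shape [B, H, W]
conf = torch.rand(2, 4, 4)                    # confidence of the weak-view prediction
ignore_mask = torch.zeros(2, 4, 4)            # 255 would mark padded regions
conf_thresh = 0.95

valid = ignore_mask != 255                    # pixels outside the padded regions
high_conf = (conf >= conf_thresh) & valid     # pixels that actually contribute loss

# Normalization as in the snippet above: divide by the number of valid pixels
loss_a = torch.sum(loss * high_conf) / torch.sum(valid)

# Equivalent view: mean over confident pixels, scaled by an adaptive weight
adaptive_w = high_conf.sum() / valid.sum()    # fraction of confident pixels
loss_b = torch.sum(loss * high_conf) / high_conf.sum().clamp(min=1) * adaptive_w

assert torch.allclose(loss_a, loss_b)
```

Dividing by the valid-pixel count rather than the confident-pixel count is exactly what makes the loss shrink when few pixels are confident.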
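
For point 2, a sketch of confidence thresholding with a sigmoid output, assuming a single-channel binary prediction (all names are hypothetical):

```python
import torch

logits_w = torch.randn(2, 4, 4)               # weak-view prediction, shape [B, H, W]
logits_s = torch.randn(2, 4, 4)               # strong-view prediction

p = torch.sigmoid(logits_w)
pseudo_label = (p > 0.5).float()              # 0.5 is the classification threshold (argmax analogue)
conf = torch.max(p, 1 - p)                    # confidence of the chosen hard label
mask = conf >= 0.95                           # keeps only p > 0.95 or p < 0.05

criterion_u = torch.nn.BCEWithLogitsLoss(reduction='none')
loss_u = criterion_u(logits_s, pseudo_label)
loss_u = torch.sum(loss_u * mask) / mask.numel()
```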
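
For point 3, a fixed weight would simply scale the unsupervised term, in contrast to the adaptive ratio above (the 0.25 here is only a placeholder value, not a recommendation from the paper):

```python
import torch

loss_x = torch.tensor(0.8)           # supervised loss on labeled images (placeholder)
loss_u = torch.tensor(0.6)           # unsupervised loss on unlabeled images (placeholder)
lambda_u = 0.25                      # fixed weight; tune according to your observations
loss = loss_x + lambda_u * loss_u
```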