LiheYoung / UniMatch

[CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation
https://arxiv.org/abs/2208.09910
MIT License

Questions for prediction mode #76

Closed DeepHM closed 1 year ago

DeepHM commented 1 year ago

I have a question about the prediction mode for pseudo-labels. All my questions are related to "unimatch.py".

[ Question 1 ] Your reasoning is: 1) FixMatch, which widely demonstrated the success of semi-supervised learning, also performs very well in semantic segmentation. 2) Therefore, if the perturbation space is utilized more fully, performance can be improved in a simple way. Is this right?

[ Question 2 ] Looking at the "unimatch.py" code:

1) The labeled loss is computed between the logits from the model forward on the labeled images and their ground-truth labels.

2) loss_u_s1 (unlabeled loss 1) is computed between pred_u_s1 and mask_u_w_cutmixed1. Here pred_u_s1 is the output logits for the CutMixed strongly-augmented unlabeled images 1, and mask_u_w_cutmixed1 is a mixture of the mask predicted by the model forward on the weakly-augmented unlabeled images 1 and the mask predicted under torch.no_grad() on the weakly-augmented unlabeled images 2. loss_u_s2 is analogous, with the two batches swapped.

3) loss_u_w_fp (the unlabeled feature-perturbation loss) is computed between pred_u_w_fp and mask_u_w. Here pred_u_w_fp is the feature-perturbed prediction from the model forward on the weakly-augmented unlabeled images (img_u_w), and mask_u_w is the mask prediction from the same model forward on img_u_w.

Here is what confuses me: 1) is understandable, but in the pseudo-labels of 2), why are they created from a mixture of a model-forward prediction and a torch.no_grad() prediction? And for the pseudo-labels in 3), why is the model-forward prediction used rather than a torch.no_grad() prediction?
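For reference, the mixing in 2) can be sketched as follows. This is a minimal illustration, not the repo's exact code: the helper name `cutmix_pseudo_label` and the toy tensors are mine, but the paste-inside-the-box semantics mirror how `mask_u_w_cutmixed1` is built in `unimatch.py`:

```python
import torch

def cutmix_pseudo_label(mask_u_w, mask_u_w_mix, cutmix_box):
    """Paste pseudo-labels from the mix batch into the CutMix region.

    mask_u_w:     [B, H, W] pseudo-mask for the first unlabeled batch
    mask_u_w_mix: [B, H, W] pseudo-mask for the second (mix) batch
    cutmix_box:   [B, H, W] binary mask, 1 inside the pasted region
    """
    mixed = mask_u_w.clone()
    mixed[cutmix_box == 1] = mask_u_w_mix[cutmix_box == 1]
    return mixed

# Toy example: 4x4 masks, paste the top-left 2x2 quadrant.
mask1 = torch.zeros(1, 4, 4, dtype=torch.long)
mask2 = torch.ones(1, 4, 4, dtype=torch.long)
box = torch.zeros(1, 4, 4, dtype=torch.long)
box[:, :2, :2] = 1
mixed = cutmix_pseudo_label(mask1, mask2, box)
print(mixed[0].tolist())  # top-left 2x2 is 1, rest is 0
```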

LiheYoung commented 1 year ago

The torch.no_grad() is not necessary. In fact, you can directly forward the images by

preds, preds_fp = model(torch.cat((img_x, img_u_w, img_u_w_mix)), True)

and then obtain pred_u_w and pred_u_w_mix from preds.
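A self-contained sketch of that single combined forward, assuming a stand-in model (the real network is DeepLabV3+ with a feature-perturbation branch toggled by the second argument; `TinySeg` here is only illustrative):

```python
import torch
import torch.nn as nn

class TinySeg(nn.Module):
    """Stand-in segmentation net with an fp (feature-perturbation) branch."""
    def __init__(self, num_classes=3):
        super().__init__()
        self.head = nn.Conv2d(3, num_classes, 1)

    def forward(self, x, need_fp=False):
        out = self.head(x)
        if need_fp:
            # fp branch: same head on perturbed input (dropout stands in
            # for the feature-level dropout used in the real model)
            out_fp = self.head(nn.functional.dropout(x, 0.5))
            return out, out_fp
        return out

model = TinySeg()
img_x = torch.randn(2, 3, 8, 8)        # labeled batch
img_u_w = torch.randn(2, 3, 8, 8)      # weakly-augmented unlabeled batch
img_u_w_mix = torch.randn(2, 3, 8, 8)  # weakly-augmented mix batch

# One forward over all three batches, then split along the batch dim.
preds, preds_fp = model(torch.cat((img_x, img_u_w, img_u_w_mix)), True)
pred_x, pred_u_w, pred_u_w_mix = preds.split([2, 2, 2])
print(pred_u_w.shape, pred_u_w_mix.shape)
```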

The most important thing is that you should set model.train() when forwarding img_u_w, so that you obtain clean and informative BN statistics, see here.

The torch.no_grad() is just to speed up the pseudo-labeling pass for img_u_w_mix. Besides, we set model.eval() for it because we have already obtained informative BN statistics from img_u_w.
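The two points above can be sketched together. This is a minimal illustration with a toy BatchNorm model, not the repo's training loop: eval() plus no_grad() makes the mix-batch pseudo-labeling cheap and leaves BN running statistics untouched, while the main forward runs in train() mode so BN statistics come from the clean weakly-augmented batch:

```python
import torch
import torch.nn as nn

# Toy model with BatchNorm, standing in for the real segmentation net.
model = nn.Sequential(nn.Conv2d(3, 4, 1), nn.BatchNorm2d(4))
img_u_w = torch.randn(2, 3, 8, 8)
img_u_w_mix = torch.randn(2, 3, 8, 8)

# 1) Pseudo-label the mix batch first: eval() freezes BN running stats,
#    no_grad() skips building the autograd graph, so this pass is cheap.
model.eval()
with torch.no_grad():
    pred_u_w_mix = model(img_u_w_mix)
    mask_u_w_mix = pred_u_w_mix.argmax(dim=1)

# 2) Switch back to train() before the main forward so BN statistics
#    are updated from the clean weakly-augmented batch.
model.train()
pred_u_w = model(img_u_w)
mask_u_w = pred_u_w.argmax(dim=1)
print(mask_u_w.shape)  # per-pixel pseudo-labels, [B, H, W]
```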

LiheYoung commented 1 year ago

Closed for inactivity.