DeepHM closed this issue 1 year ago.
`torch.no_grad()` is not necessary. In fact, you can forward all the images at once with `preds, preds_fp = model(torch.cat((img_x, img_u_w, img_u_w_mix)), True)` and then obtain your `pred_u_w` and `pred_u_w_mix` from `preds`.
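A minimal sketch of that single forward pass, assuming (as in the call above) that `model(x, True)` returns a pair `(preds, preds_fp)` of clean and feature-perturbed logits. `ToySegModel`, its dropout-based perturbation, and the tensor shapes are hypothetical stand-ins, not the UniMatch network:

```python
import torch
import torch.nn as nn


class ToySegModel(nn.Module):
    """Hypothetical stand-in for a segmentation model with a feature-perturbation branch."""

    def __init__(self, in_ch: int = 3, num_classes: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()
        )
        self.head = nn.Conv2d(8, num_classes, 1)
        self.fp_dropout = nn.Dropout2d(0.5)  # assumed form of feature perturbation

    def forward(self, x: torch.Tensor, need_fp: bool = False):
        feats = self.encoder(x)
        if need_fp:
            # Decode clean and perturbed features in one batched pass, then split.
            out = self.head(torch.cat((feats, self.fp_dropout(feats))))
            return out.chunk(2)
        return self.head(feats)


model = ToySegModel().train()
b, h, w = 2, 16, 16
img_x, img_u_w, img_u_w_mix = (torch.randn(b, 3, h, w) for _ in range(3))

# One forward over labeled, weak unlabeled, and weak "mix" unlabeled images.
preds, preds_fp = model(torch.cat((img_x, img_u_w, img_u_w_mix)), True)
pred_x, pred_u_w, pred_u_w_mix = preds.split([b, b, b])
pred_u_w_fp = preds_fp[b:2 * b]  # perturbed predictions for img_u_w only
```

With this variant, `img_u_w_mix` goes through BN in training mode together with the other images and is part of the autograd graph, so any pseudo-labels derived from `pred_u_w_mix` should be taken from a detached copy.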
The most important thing is that you should set `model.train()` when forwarding `img_u_w`, so that BN obtains clean and informative statistics; see here.
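As a rough illustration of what training mode does to BN, the sketch below uses a bare `nn.BatchNorm2d` layer as a stand-in for the BN layers inside the model (the random `img_u_w` batch is hypothetical). In `train()` mode the layer normalizes with the current batch's statistics and moves its running estimates toward them:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=0.1)  # stand-in for a BN layer inside the model
img_u_w = torch.randn(4, 3, 8, 8) * 2.0 + 5.0  # hypothetical weak unlabeled batch

bn.train()
before = bn.running_mean.clone()
out = bn(img_u_w)

# In train mode the output is normalized with the batch's own statistics ...
print(out.mean(dim=(0, 2, 3)))  # approximately zero per channel

# ... and the running mean is updated toward the batch mean with momentum 0.1.
batch_mean = img_u_w.mean(dim=(0, 2, 3))
expected = (1 - 0.1) * before + 0.1 * batch_mean
print(torch.allclose(bn.running_mean, expected, atol=1e-5))  # True
```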
`torch.no_grad()` is only there to speed up the pseudo-labeling of `img_u_w_mix`. In addition, we set `model.eval()` for it because informative BN statistics have already been obtained from `img_u_w`.
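A sketch of that pseudo-labeling step under stated assumptions: a hypothetical tiny `nn.Sequential` model with one BN layer and a random `img_u_w_mix` batch. In `eval()` mode BN uses its stored running statistics and leaves them untouched, and `torch.no_grad()` skips building the autograd graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny model; any module containing BatchNorm behaves the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 4, 1)
)
img_u_w_mix = torch.randn(2, 3, 16, 16)  # hypothetical second weak unlabeled batch

bn = model[1]
stats_before = bn.running_mean.clone()

model.eval()  # BN uses its running statistics and does not update them
with torch.no_grad():  # no autograd graph: faster and lighter on memory
    pred_u_w_mix = model(img_u_w_mix)
    conf_u_w_mix, mask_u_w_mix = pred_u_w_mix.softmax(dim=1).max(dim=1)
model.train()  # switch back before the gradient-carrying forward pass

print(pred_u_w_mix.requires_grad)  # False
print(torch.equal(bn.running_mean, stats_before))  # True
```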
Closed for inactivity.
I have a question about the prediction mode for pseudo-labels. All my questions are related to "unimatch.py".
[ Question 1 ] My understanding of your approach is:
1) FixMatch, which brought wide attention to the success of semi-supervised learning, also performs very well in semantic segmentation.
2) Therefore, if the perturbation space is better exploited, performance can be improved in a simple way.
Is this right?
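For reference, the FixMatch-style consistency term mentioned in Question 1 can be sketched per pixel as follows: the weak view supplies hard pseudo-labels and confidences, and the strong view is trained only where that confidence clears a threshold. The function name, the 0.95 threshold, and the normalization by total pixel count are assumptions for illustration, not taken from `unimatch.py`:

```python
import torch
import torch.nn.functional as F


def thresholded_pseudo_label_loss(pred_strong, pred_weak, conf_thresh=0.95):
    """FixMatch-style loss: weak-view pseudo-labels supervise the strong view.

    Both inputs are logits of shape (B, C, H, W); conf_thresh is an assumed value.
    """
    with torch.no_grad():
        conf, pseudo = pred_weak.softmax(dim=1).max(dim=1)  # each (B, H, W)
    loss = F.cross_entropy(pred_strong, pseudo, reduction="none")  # (B, H, W)
    keep = (conf >= conf_thresh).float()
    # Low-confidence pixels are zeroed out; normalize by the total pixel count.
    return (loss * keep).sum() / keep.numel()


# Usage: the top row is confidently class 0, the bottom row is uncertain.
pred_weak = torch.zeros(1, 2, 2, 2)
pred_weak[:, 0, 0, :] = 10.0
pred_strong = torch.randn(1, 2, 2, 2, requires_grad=True)
loss = thresholded_pseudo_label_loss(pred_strong, pred_weak)
loss.backward()
print(pred_strong.grad[:, :, 1, :].abs().sum())  # zero: uncertain pixels are ignored
```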
[ Question 2 ] Looking at the `unimatch.py` code:
1) The labeled loss is computed between the logits of the model forward on the labeled images and their ground-truth masks.
2) `loss_u_s1`, the first unlabeled loss, is computed between `pred_u_s1` and `mask_u_w_cutmixed1`.
=> `pred_u_s1` is the output logits for the CutMixed strong unlabeled images (view 1), and `mask_u_w_cutmixed1` mixes the pseudo-labels from the model forward on the first batch of weak unlabeled images (`img_u_w`) with the pseudo-labels computed under `torch.no_grad()` for the second batch of weak unlabeled images (`img_u_w_mix`).
=> `loss_u_s2` is computed in the same way for the second strong view.
3) `loss_u_w_fp`, the unlabeled feature-perturbation loss, is computed between `pred_u_w_fp` and `mask_u_w`.
=> `pred_u_w_fp`: predictions from the feature-perturbed branch of the model forward on the weak unlabeled images (`img_u_w`).
=> `mask_u_w`: pseudo-labels from the unperturbed model forward on the same images (`img_u_w`).
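The construction of `mask_u_w_cutmixed1` described in 2) can be sketched as pasting the second batch's pseudo-labels inside the CutMix box, exactly where the strong image received pixels from the second batch. Shapes, values, and the box location below are hypothetical; only the variable names follow the question:

```python
import torch

B, C, H, W = 2, 3, 4, 4
img_u_s1 = torch.zeros(B, C, H, W)      # strong view of batch 1 (placeholder values)
img_u_s1_mix = torch.ones(B, C, H, W)   # strong view of batch 2 (placeholder values)
mask_u_w = torch.zeros(B, H, W, dtype=torch.long)          # pseudo-labels, batch 1
mask_u_w_mix = torch.full((B, H, W), 2, dtype=torch.long)  # pseudo-labels, batch 2

# Hypothetical CutMix box: the top-left 2x2 region of every image.
cutmix_box1 = torch.zeros(B, H, W, dtype=torch.long)
cutmix_box1[:, :2, :2] = 1

# Mix the images: paste batch-2 pixels inside the box.
box_img = cutmix_box1.unsqueeze(1).expand(B, C, H, W) == 1
img_u_s1[box_img] = img_u_s1_mix[box_img]

# Mix the pseudo-labels the same way, so each pixel keeps the label of its source image.
mask_u_w_cutmixed1 = mask_u_w.clone()
mask_u_w_cutmixed1[cutmix_box1 == 1] = mask_u_w_mix[cutmix_box1 == 1]
```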
Here is what confuses me. 1) is clear, but why are the pseudo-labels in 2) built from a mixture of a regular model forward and a `torch.no_grad()` forward? And why are the pseudo-labels in 3) taken from the regular model forward rather than from results computed under `torch.no_grad()`?
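On the distinction raised in 3): pseudo-labels taken from a gradient-carrying forward pass can simply be detached, which yields the same target values as recomputing them under `torch.no_grad()` in the same mode; only the autograd bookkeeping differs. A minimal sketch with random placeholder logits and an assumed 0.95 confidence threshold:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 4, 8, 8
# Random placeholders for the two outputs of one training-mode forward on img_u_w.
pred_u_w = (torch.randn(B, C, H, W) * 5).requires_grad_()     # clean branch
pred_u_w_fp = (torch.randn(B, C, H, W) * 5).requires_grad_()  # feature-perturbed branch

# Detach: the clean prediction becomes a fixed target; no gradient flows back into it.
conf_u_w, mask_u_w = pred_u_w.detach().softmax(dim=1).max(dim=1)

# Supervise the perturbed branch only where the clean branch is confident (assumed 0.95).
loss_u_w_fp = F.cross_entropy(pred_u_w_fp, mask_u_w, reduction="none")
loss_u_w_fp = (loss_u_w_fp * (conf_u_w >= 0.95).float()).mean()
loss_u_w_fp.backward()

print(pred_u_w.grad is None)     # True: the pseudo-label source receives no gradient
print(pred_u_w_fp.grad is None)  # False
```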