googleinterns / wss

A simple consistency training framework for semi-supervised image semantic segmentation
Apache License 2.0

Concern about the details of the comparison results in Table 2 #2

Closed. PkuRainBow closed this issue 3 years ago.

PkuRainBow commented 4 years ago

Really nice paper!

We carefully read your work and found the experimental setting on Pascal VOC in Table 2 (shown below) really interesting: in the last column of Table 2, all the methods use only 92 images as the labeled set and the train-aug set (10,582 images) as the unlabeled set, according to the code:

https://github.com/googleinterns/wss/blob/8069dbe8b68b409a891224508f35c6ae5ecec4c9/core/data_generator.py#L85-L104

and,

https://github.com/googleinterns/wss/blob/280cc1a6ceb5326044ee7521706d3d293c4aeb40/train_wss.py#L796

Our understanding is that FLAGS.train_split_cls represents the set of unlabeled images used for training, and its value is train_aug by default. So the number of unlabeled images is more than 100x the number of labeled images. Given that the total number of training iterations is set to training_number_of_steps=30000, and assuming a batch size of 64, we will iterate over the sampled 92 labeled images for roughly 30000x64/92 ≈ 20,869 epochs. Is my understanding correct?
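For concreteness, the arithmetic behind that estimate (assuming the labeled batch size of 64 used in the calculation above) is:

```python
# Back-of-the-envelope check of how many epochs the 92 labeled images see.
# The batch size of 64 is an assumption taken from the calculation above.
training_number_of_steps = 30000
labeled_batch_size = 64
num_labeled_images = 92

epochs_over_labeled_set = training_number_of_steps * labeled_batch_size / num_labeled_images
print(epochs_over_labeled_set)  # ~20869.6
```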

If my understanding is correct, we are curious whether training for so many epochs on the 92 labeled images is a good choice. Besides, since the train-aug set (10,582 images) contains the 92 labeled images, we guess that all the methods also apply the pseudo-label/consistency losses to the labeled images (instead of only to the unlabeled images).

Many thanks, and we look forward to your explanation in case our understanding is wrong!

[Screenshot of Table 2 from the paper]

Yuliang-Zou commented 4 years ago

Hi @PkuRainBow

1) Yes, your understanding is correct. We use the same number of iterations for all the data splits because we need to iterate through the unlabeled set enough times (if you count epochs with respect to the unlabeled set, all splits are trained for the same number of epochs).

2) Yes, those 92 images are also in the unlabeled set. I follow the common practice in SSL classification here.

BTW, we sampled those 92 images so that the number of pixels for each class is roughly balanced. You might not always get a good result if you pick an arbitrary set of 92 images (see Appendix C).
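To make the idea concrete, here is a rough sketch of one way such class-balanced selection could be done (this is only an illustrative heuristic, not the actual sampling code we used; the input format and the greedy criterion are assumptions):

```python
import numpy as np

def select_balanced_subset(per_image_class_pixels, num_images=92):
    """Greedily pick images so per-class pixel counts stay roughly balanced.

    per_image_class_pixels: array of shape [num_candidates, num_classes],
    where entry [i, c] is the number of pixels of class c in image i.
    Illustrative heuristic only.
    """
    per_image_class_pixels = np.asarray(per_image_class_pixels, dtype=np.float64)
    num_classes = per_image_class_pixels.shape[1]
    class_totals = np.zeros(num_classes)
    selected, remaining = [], set(range(len(per_image_class_pixels)))

    for _ in range(num_images):
        # Pick the image that minimizes the spread between the most and
        # least represented classes after it is added.
        best_i, best_spread = None, None
        for i in remaining:
            totals = class_totals + per_image_class_pixels[i]
            spread = totals.max() - totals.min()
            if best_spread is None or spread < best_spread:
                best_i, best_spread = i, spread
        selected.append(best_i)
        class_totals += per_image_class_pixels[best_i]
        remaining.remove(best_i)
    return selected
```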

PkuRainBow commented 3 years ago

@Yuliang-Zou Many thanks for your explanation. We still have a small concern about your experimental setting.

According to your explanation, your method in fact trains over the 92 images (the labeled set) for about 20,869 epochs, which might cause a serious overfitting problem in the supervised part of the training. We also found that the authors of CutMix face the same challenge; we paste the discussion here: https://github.com/Britefury/cutmix-semisup-seg/issues/5#issuecomment-720367128

So we are really interested in how your experimental setting addresses this overfitting problem. We look forward to your explanation!

Yuliang-Zou commented 3 years ago

I don't have a clear answer yet, but I guess it could be related to the training schedule. In the beginning, the supervised loss dominates the optimization; as we train for more iterations, the unsupervised loss starts to take effect and gradually dominates the total loss. Just for your reference, FixMatch (semi-supervised classification) has an experiment training CIFAR-10 with only 10 labeled images, and it works quite well.
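To illustrate why the unsupervised term only takes effect gradually, here is a rough sketch of a FixMatch-style confidence-thresholded pseudo-label loss for per-pixel predictions (illustrative only; the function name, shapes, and threshold value are assumptions, not the exact code in this repo):

```python
import tensorflow as tf

def pseudo_label_loss(weak_logits, strong_logits, confidence_threshold=0.95):
    """FixMatch-style per-pixel pseudo-label loss.

    weak_logits / strong_logits: [batch, height, width, num_classes] logits
    from weakly / strongly augmented views of the same unlabeled images.
    Early in training few pixels pass the confidence threshold, so this term
    is small; as predictions become confident, more pixels contribute and the
    unsupervised loss gradually dominates.
    """
    probs = tf.stop_gradient(tf.nn.softmax(weak_logits, axis=-1))
    pseudo_labels = tf.argmax(probs, axis=-1)            # [B, H, W]
    confidence = tf.reduce_max(probs, axis=-1)           # [B, H, W]
    mask = tf.cast(confidence >= confidence_threshold, tf.float32)

    per_pixel_ce = tf.keras.losses.sparse_categorical_crossentropy(
        pseudo_labels, strong_logits, from_logits=True)  # [B, H, W]
    return tf.reduce_sum(mask * per_pixel_ce) / (tf.reduce_sum(mask) + 1e-6)
```

The total loss would then be the supervised cross-entropy on the labeled batch plus this unlabeled term.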

PkuRainBow commented 3 years ago

@Yuliang-Zou Thanks for your reply. The balance between the supervised loss and the unsupervised loss might be the key to avoiding this problem. If our understanding is correct, it is important to ensure that the unsupervised loss dominates in the later stage of training. However, there seems to be no explicit mechanism to guarantee such behavior, so we guess that an explicit re-weighting scheme might address this problem.
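For example, an explicit sigmoid ramp-up of the unsupervised loss weight, in the spirit of the schedule used by Laine & Aila, would be one such mechanism (a hypothetical sketch; the ramp-up length and maximum weight below are made-up values):

```python
import numpy as np

def unsup_weight(step, ramp_up_steps=10000, max_weight=1.0):
    """Sigmoid-shaped ramp-up of the unsupervised loss weight.

    Hypothetical values: ramps from ~0 to max_weight over ramp_up_steps.
    """
    t = np.clip(step / ramp_up_steps, 0.0, 1.0)
    return max_weight * np.exp(-5.0 * (1.0 - t) ** 2)

# total_loss = supervised_loss + unsup_weight(step) * unsupervised_loss
```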