TencentYoutuResearch / Classification-SemiCLS

Code for CVPR 2022 paper “Class-Aware Contrastive Semi-Supervised Learning”

question about implementation of eq(5) #9

Closed randydkx closed 1 year ago

randydkx commented 2 years ago

The implementation of Eq. (5) seems to differ from its definition: for i = j, the 'SoftSupConLoss' function (line 55) yields w_{ij}^{target} = q_i^2 * w_{ij}^{clacon}.

KaiWU5 commented 2 years ago

Eq. (5) is the hard version of the class-aware contrastive matrix, without re-weighting. In the code it corresponds to select_matrix * mask on line 55, and score_mask is the re-weighting factor.

Also note that Eq. (8) does have a typo, which we have already corrected in the arXiv version (see #7): https://arxiv.org/abs/2203.02261

randydkx commented 2 years ago

@KaiWU5 Thanks for your reply. I understand that, but in the implementation the re-weighting factor is q_i^2 for i = j, whereas according to your paper it should be 1.

KaiWU5 commented 2 years ago

Thanks for pointing this out, and sorry for the confusion. i and j are indices into the columns and rows of the mask matrix. When writing Eq. (6) for i = j (augmentations of the same image), we wanted to express that the same sample's embeddings should be the same. In the actual implementation, however, this case is ignored (see the denominator of Eq. (8)): as in many self-supervised and semi-supervised learning methods, the same sample's embedding does not count toward the final loss. The code follows Eq. (8). I will add some comments to the code to clarify this. Thanks.
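To make the distinction between "the same embedding" and "another view of the same image" concrete, here is a minimal sketch in plain Python rather than the repository's PyTorch code. It assumes the loss follows the usual SupCon-style construction, where only the diagonal of the full (batch_size * n_views) matrix is excluded. The names `self_pair_mask` and `same_image`, and the view-major row ordering, are assumptions for illustration and do not come from the repository.

```python
def self_pair_mask(batch_size, n_views):
    """1 for every (anchor, contrast) pair except an embedding paired with itself."""
    n = batch_size * n_views
    return [[0 if row == col else 1 for col in range(n)] for row in range(n)]


def same_image(row, col, batch_size):
    """True when two rows of the assumed view-major layout come from the same image."""
    return row % batch_size == col % batch_size


batch_size, n_views = 2, 2
mask = self_pair_mask(batch_size, n_views)

# Rows 0 and 2 are the two views of image 0 under the assumed ordering.
print(mask[0][0])                                # anchor vs. itself
print(mask[0][2], same_image(0, 2, batch_size))  # anchor vs. its other view
```

Under these assumptions the pair (0, 0) is dropped from the loss, while the pair (0, 2) between two views of the same image is kept, so its re-weighting factor still matters.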

JCZ404 commented 1 year ago

@KaiWU5 I think you may have misunderstood the problem raised by @randydkx. In your code, you first compute score_mask according to Eq. (6). It has shape (batch_size, batch_size), so its diagonal elements are q_i^2. You then repeat this score_mask n_views times, so the re-weighting mask has shape (batch_size * n_views, batch_size * n_views). As a result, the re-weighting factor between the i-th image and its other augmented views is q_i^2, and these samples are treated as positives and enter the loss as in the first term of Eq. (9). In your paper, however, the re-weighting factor for these positive samples is 1.
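The effect described above can be reproduced with a small sketch. It uses plain Python lists instead of tensors so it runs without dependencies; `score_mask` and `tile` are hypothetical helpers, and `tile` is only assumed to mimic the repeat step in the repository (similar in spirit to `Tensor.repeat(n_views, n_views)`).

```python
def score_mask(q):
    """Batch-level re-weighting matrix with entries q_i * q_j."""
    return [[qi * qj for qj in q] for qi in q]


def tile(mask, n_views):
    """Repeat a (B, B) matrix n_views times along both axes (read-only use)."""
    return [row * n_views for row in mask] * n_views


q = [0.5, 0.25]                       # illustrative confidences, not real data
full = tile(score_mask(q), n_views=2)

# Row 0 is view 1 of image 0; column 2 is view 2 of image 0.
print(full[0][2])                     # 0.25, i.e. q_0 ** 2 rather than 1
print(full[1][3])                     # 0.0625, i.e. q_1 ** 2 rather than 1
```

In this toy setup the weight on each positive pair formed by two views of the same image is q_i^2, which is the mismatch with the paper's value of 1.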


Really hope you could give a detailed explanation about this, thanks!

KaiWU5 commented 1 year ago

Thanks for the clear explanation. I think this is a bug in the code, because mathematically, setting the weight to 1 should be better. After setting the diagonal to 1, I obtained 75.88 on cifar2500, which is still within the range reported in the paper, 75.7±0.63.

My guess is that each augmentation has several hundred pairs, so changing the weight of only one pair is not critical. I have changed the code: https://github.com/TencentYoutuResearch/Classification-SemiCLS/blob/main/loss/soft_supconloss.py#L78.
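For readers who want to see the idea without opening the repository, the following is a rough sketch of the fix as described in this thread: set the diagonal of the batch-level mask to 1 before repeating it. It is not the code at the linked line; `fixed_score_mask`, `tile`, and `affected_fraction` are hypothetical names, and plain Python lists stand in for tensors.

```python
def fixed_score_mask(q):
    """q_i * q_j off the diagonal, 1.0 on the diagonal (same image)."""
    n = len(q)
    return [[1.0 if i == j else q[i] * q[j] for j in range(n)] for i in range(n)]


def tile(mask, n_views):
    """Repeat a (B, B) matrix n_views times along both axes (read-only use)."""
    return [row * n_views for row in mask] * n_views


def affected_fraction(batch_size, n_views):
    """Share of an anchor's non-self pairs whose weight the fix changes."""
    return (n_views - 1) / (batch_size * n_views - 1)


q = [0.5, 0.25]                        # illustrative confidences, not real data
full = tile(fixed_score_mask(q), n_views=2)

print(full[0][2])                      # 1.0: two views of image 0
print(full[0][1])                      # 0.125: different images keep q_0 * q_1
print(affected_fraction(batch_size=4, n_views=2))
```

With two views, only one pair per anchor changes weight, so `affected_fraction` shrinks quickly as the batch grows, which is consistent with the small change in accuracy reported above.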