LTH14 / targeted-supcon

A PyTorch implementation of the paper Targeted Supervised Contrastive Learning for Long-tailed Recognition
MIT License

Seems the code is inconsistent with the paper #12

Open liluhu0 opened 1 year ago

liluhu0 commented 1 year ago

First of all, thanks for such excellent work! I have a question about your code, which seems to be inconsistent with the paper.

Your code:

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)) / mask_pos_view.shape[0]

But according to the formula in the paper, it seems it should be:

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1)) / mask_pos_view.shape[0]

That is, the target loss in the paper does not seem to be divided by (k+1), which affects the choice of tw. The paper reports tw=0.2 as optimal: was this value obtained with the formula in the paper or with the one in the code? I look forward to your answer, thank you!
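To make the scaling difference concrete, here is a minimal plain-Python sketch that mirrors the two PyTorch expressions on a toy batch. All values are hypothetical. That includes the masks: which column counts as the "target" positive in mask_pos_view_target is made up here.

The code version divides each anchor's term by its number of positives, mask_pos_view.sum(1). The paper as written omits that division. The two losses therefore differ by roughly that factor, and a tw tuned under one formulation does not carry over directly to the other.

```python
# Hypothetical toy stand-ins for the tensors in the issue.
# Rows are anchors (batch dimension), columns are candidate keys.
log_prob = [
    [-0.5, -1.0, -2.0, -3.0],
    [-1.5, -0.5, -2.5, -1.0],
]
# 1 where the column is a positive for the anchor (here k+1 = 3 per row).
mask_pos_view = [
    [1, 1, 1, 0],
    [1, 1, 0, 1],
]
# 1 where the column is the target positive (hypothetical choice).
mask_pos_view_target = [
    [0, 0, 1, 0],
    [0, 0, 0, 1],
]


def masked_row_sums(mask, values):
    # Equivalent of (mask * values).sum(1) for nested lists.
    return [sum(m * v for m, v in zip(mrow, vrow)) for mrow, vrow in zip(mask, values)]


def loss_as_in_code(mask_target, mask_view, log_prob):
    # - torch.sum((mask_target * log_prob).sum(1) / mask_view.sum(1)) / mask_view.shape[0]
    num = masked_row_sums(mask_target, log_prob)
    den = [sum(row) for row in mask_view]
    return -sum(a / b for a, b in zip(num, den)) / len(mask_view)


def loss_as_in_paper(mask_target, mask_view, log_prob):
    # - torch.sum((mask_target * log_prob).sum(1)) / mask_view.shape[0]
    num = masked_row_sums(mask_target, log_prob)
    return -sum(num) / len(mask_view)


code_loss = loss_as_in_code(mask_pos_view_target, mask_pos_view, log_prob)
paper_loss = loss_as_in_paper(mask_pos_view_target, mask_pos_view, log_prob)
print(f"code: {code_loss:.4f}  paper as written: {paper_loss:.4f}")
# code: 0.5000  paper as written: 1.5000
```

Every anchor in this toy batch has 3 positives, so the paper-as-written value is exactly 3 times the code value. With a varying number of positives per anchor, the ratio would differ from row to row.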

LTH14 commented 1 year ago

Hi, thanks for pointing this out! I just double-checked the paper; it seems the paper is missing a parenthesis in the two contrastive losses. Please follow the code here; the optimal tw was computed using the code.

liluhu0 commented 1 year ago

Okay, thanks for your response!

suminRoh commented 1 year ago

Hi, I have a question. Doesn't (k+1) appear in mask_pos_view.shape[0] in the code? I think the correct code is:

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1))

LTH14 commented 1 year ago

That's true, but you still need to divide by the batch size, which is mask_pos_view.shape[0].
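One way to see why the division by batch size matters is to compare the two variants directly. With the division, the loss is the mean of the per-sample losses, so it stays the same if the same samples are simply repeated. Without it, the loss is a sum that grows with the batch size, and so do its gradients.

The sketch below uses hypothetical per-sample loss values in plain Python. The helper names are illustrative only.

```python
def summed_loss(per_sample):
    # Sum over the batch only (the variant without "/ mask_pos_view.shape[0]").
    return sum(per_sample)


def mean_loss(per_sample):
    # Sum over the batch, then divide by the batch size (the variant in the repository code).
    return sum(per_sample) / len(per_sample)


batch = [0.5, 1.0, 1.5]   # hypothetical per-sample losses
doubled = batch + batch   # same samples, twice the batch size

print(summed_loss(batch), summed_loss(doubled))  # 3.0 6.0
print(mean_loss(batch), mean_loss(doubled))      # 1.0 1.0
```

Only the mean is independent of the batch size. This is the usual reason per-batch losses are averaged rather than summed.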

suminRoh commented 1 year ago

Why do I have to divide by the batch size? In main_moco_supcon_imaba.py, doesn't the AverageMeter for the losses already compute the average loss?

Also, if I have to divide by the batch size, which is mask_pos_view.shape[0], doesn't loss_class then have to be divided by mask_pos_view.shape[0] twice, once for the batch size and once for (k+1)?

LTH14 commented 1 year ago

Basically, the loss for each data point is (mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1). We then sum over all data points and divide by the batch size, which is equivalent to computing the average loss:

torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)) / mask_pos_view.shape[0]
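The AverageMeter question can be checked in the same spirit. The sketch below assumes the meter follows the common pattern from the PyTorch ImageNet example: update(val, n) accumulates val * n, and avg divides by the total count. That is an assumption about main_moco_supcon_imaba.py, not something verified against it.

Under that pattern, the meter averages across batches and weights each batch by n. It therefore expects val to already be the per-sample mean for that batch. The batch values below are hypothetical.

```python
class AverageMeter:
    """Running average in the style of the PyTorch ImageNet example (assumed, not copied from the repo)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is treated as an average over n samples, so it is weighted by n.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def per_sample_losses(masked_log_prob_sums, num_positives):
    # Per data point: -(mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)
    return [-s / p for s, p in zip(masked_log_prob_sums, num_positives)]


# Two hypothetical batches of different sizes:
# (row sums of the masked log_prob, positives per anchor)
batches = [
    ([-2.0, -4.0], [2, 2]),
    ([-3.0, -6.0, -9.0], [3, 3, 3]),
]

meter_mean = AverageMeter()  # fed the batch mean, as in the repo code
meter_sum = AverageMeter()   # fed the batch sum, without the batch-size division
all_losses = []
for sums, positives in batches:
    losses = per_sample_losses(sums, positives)
    all_losses.extend(losses)
    batch_size = len(losses)
    meter_mean.update(sum(losses) / batch_size, batch_size)
    meter_sum.update(sum(losses), batch_size)

print(meter_mean.avg, sum(all_losses) / len(all_losses))  # 1.8 1.8
print(meter_sum.avg)                                      # 4.8
```

Fed the batch mean, the meter reproduces the average loss over all samples. Fed an un-normalized batch sum, its average is inflated by roughly the batch size. Under this assumed meter, the division inside the loss and the averaging in the meter play different roles, so the division is not redundant.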

suminRoh commented 1 year ago

I understand. Thank you for the detailed explanation!