This is a good paper, and I was interested in the idea of DSCL, so I took a look at the code implementation.
The implementation of DSCL is clearly based on Supervised Contrastive Learning (SCL). But when I read the code, I found that formula 9 in the paper does not match it. Below is the relevant code from the project.
class_weighted = torch.ones_like(mask) / (mask.sum(dim=1, keepdim=True) - 1.0 + 1e-12) * self.weighted_beta
class_weighted = class_weighted.scatter(1, torch.arange(batch_size).view(-1, 1).to(device) + batch_size, 1.0)
Does self.weighted_beta correspond to $\alpha$ in the paper? If so, how does formula 9 represent it?
In addition, mean_log_prob_pos is calculated in the code as follows:
mean_log_prob_pos = (mask * log_prob * class_weighted).sum(1) / (mask * class_weighted).sum(1)
Isn't this the same as the SCL calculation? class_weighted appears in both the numerator and the denominator, so it looks as though it cancels out and is not really involved in the result.
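To make the second question concrete, here is a minimal sketch in plain Python with made-up numbers (the log-probabilities and the weights 0.5 and 0.25 are hypothetical, not taken from the project). It only illustrates a general property of a weighted mean of the form sum(w * v) / sum(w): the weights cancel when every positive has the same weight, but not when one entry is overwritten with a different value, which is what the scatter call appears to do.

```python
def weighted_mean(values, weights):
    """Return sum(w * v) / sum(w), the same form as mean_log_prob_pos."""
    total_weight = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total_weight


# Hypothetical log-probabilities of three positives for a single anchor.
log_prob = [-1.0, -2.0, -4.0]

# Plain (unweighted) mean over the positives, as in SCL.
plain = sum(log_prob) / len(log_prob)

# Case 1: all positives share one weight, so the weight cancels out.
uniform = weighted_mean(log_prob, [0.5, 0.5, 0.5])

# Case 2: one entry is set to 1.0 while the others keep a smaller
# (assumed) weight, so the result is no longer the plain mean.
mixed = weighted_mean(log_prob, [1.0, 0.25, 0.25])

print(plain, uniform, mixed)
```

If this reading is right, whether the code reduces to SCL depends on whether the remaining weights also equal 1.0, which is part of what I would like the authors to confirm.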
These two questions are my current doubts. If you see this, I would appreciate a detailed explanation. Thank you.