Shen-Lab / GraphCL

[NeurIPS 2020] "Graph Contrastive Learning with Augmentations" by Yuning You, Tianlong Chen, Yongduo Sui, Ting Chen, Zhangyang Wang, Yang Shen
MIT License

Questions about loss function calculation #2

Closed: cGy147 closed this issue 3 years ago

cGy147 commented 3 years ago

Hi, in GraphCL/unsupervised_TU/gsimclr:

import torch

def loss_cal(self, x, x_aug):
    T = 0.2  # temperature
    batch_size, _ = x.size()
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)

    # Pairwise cosine similarity between x[i] and x_aug[j], divided by T and exponentiated
    sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs)
    sim_matrix = torch.exp(sim_matrix / T)
    # Diagonal entries are the positive pairs (x[i], x_aug[i])
    pos_sim = sim_matrix[range(batch_size), range(batch_size)]
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    loss = - torch.log(loss).mean()

    return loss
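For experimenting outside the model class, the method above can be lifted into a free function. The name gsimclr_loss and the toy inputs are hypothetical and not part of the GraphCL repository; apart from dropping self and the batch_size bookkeeping, and reading the diagonal with .diag() instead of range indexing, the body follows loss_cal.

```python
import torch


def gsimclr_loss(x: torch.Tensor, x_aug: torch.Tensor, T: float = 0.2) -> torch.Tensor:
    # Hypothetical free-function copy of loss_cal (no `self`), for experimentation only.
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)

    # sim_matrix[i, j] = exp(cos(x[i], x_aug[j]) / T)
    sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs)
    sim_matrix = torch.exp(sim_matrix / T)

    # Same entries as sim_matrix[range(batch_size), range(batch_size)]: the positive pairs
    pos_sim = sim_matrix.diag()
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    return -torch.log(loss).mean()


# Toy input: 4 orthonormal embeddings, each paired with an identical "augmented" view
x = torch.eye(4)
print(gsimclr_loss(x, x.clone()).item())  # ≈ -3.901, i.e. log(3) - 5
```

With orthonormal embeddings every positive pair has cosine 1 and every negative pair cosine 0, so each row contributes -log(e^5 / 3) and the mean is log(3) - 5.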

Why does the denominator of the loss need to subtract pos_sim?

yyou1996 commented 3 years ago

Hi @cGy147,

Thanks for your interest. Recall that the GraphCL loss is defined as -log( sim(positive pair) / sum of sim(negative pairs) ), where the negative pairs exclude the positive ones (the diagonal values of sim_matrix). That is why the subtraction is performed here.
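Subtracting pos_sim from the row sum is simply a cheap way of summing each row over its off-diagonal entries only. The sketch below compares it with an explicit diagonal mask; both helper names are hypothetical, and the random matrix merely stands in for sim_matrix.

```python
import torch


def negatives_by_mask(sim_matrix: torch.Tensor) -> torch.Tensor:
    # Zero out the diagonal (positive pairs), then sum each row: N-1 negatives per anchor.
    n = sim_matrix.size(0)
    diag_mask = torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
    return sim_matrix.masked_fill(diag_mask, 0.0).sum(dim=1)


def negatives_by_subtraction(sim_matrix: torch.Tensor) -> torch.Tensor:
    # What loss_cal does: full row sum minus the diagonal (positive) term.
    return sim_matrix.sum(dim=1) - sim_matrix.diag()


torch.manual_seed(0)
# Stand-in for exp(cos / T): any matrix of positive values works for this check.
sim = torch.exp(torch.rand(5, 5, dtype=torch.float64))
print(torch.allclose(negatives_by_mask(sim), negatives_by_subtraction(sim)))  # True
```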

cGy147 commented 3 years ago

But in your paper: [image: Equation 3 from the paper]
I think your implementation differs from Equation 3, whose denominator consists of 1 positive pair and N-1 negative pairs. Am I misunderstanding something?
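To make the difference concrete, here is a sketch that computes both denominators: the one in loss_cal (negatives only) and the one the question reads from the original Equation 3 (one positive plus N-1 negatives). The function name and the include_positive switch are hypothetical, and cosine similarity is computed with F.normalize instead of the einsum form, which gives the same values for non-zero rows.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(x: torch.Tensor, x_aug: torch.Tensor, T: float = 0.2,
                     include_positive: bool = False) -> torch.Tensor:
    # include_positive=False: denominator holds only the N-1 negatives (as in loss_cal).
    # include_positive=True: the positive is added back, i.e. 1 positive + N-1 negatives.
    sim = torch.exp(F.normalize(x, dim=1) @ F.normalize(x_aug, dim=1).T / T)
    pos = sim.diag()
    denom = sim.sum(dim=1) if include_positive else sim.sum(dim=1) - pos
    return -torch.log(pos / denom).mean()


x = torch.eye(4)  # toy input: positives have cosine 1, negatives cosine 0
print(contrastive_loss(x, x, include_positive=False).item())  # ≈ -3.901 (can be negative)
print(contrastive_loss(x, x, include_positive=True).item())   # ≈ 0.020 (never negative)
```

Per anchor the two versions differ by exactly log(1 + pos_i / neg_i), where neg_i is the sum over negatives, so including the positive always gives a strictly larger, non-negative loss.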

yyou1996 commented 3 years ago

Hi @cGy147,

Yeah, thanks a lot for pointing this out. The formulation should be corrected: the denominator should be $\sum_{k=1, k \neq i}^{N} \exp(\mathrm{sim}(k\text{-th negative pair})/t)$. I will submit a corrected version to arXiv today.
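Written out with the variable names of loss_cal (an assumed notation; the paper's own symbols may differ), the corrected per-anchor loss and its batch mean would read as follows, with $x_i$ = x[i], $\tilde{x}_k$ = x_aug[k] and $N$ = batch_size:

```latex
\ell_i = -\log \frac{\exp(s_{ii}/T)}{\sum_{k=1,\,k \neq i}^{N} \exp(s_{ik}/T)},
\qquad
s_{ik} = \frac{x_i^{\top}\tilde{x}_k}{\lVert x_i \rVert \, \lVert \tilde{x}_k \rVert},
\qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \ell_i
```

Each term maps onto one line of the code: $\exp(s_{ik}/T)$ is sim_matrix[i, k], the numerator is pos_sim[i], the $k \neq i$ sum is sim_matrix.sum(dim=1) - pos_sim, and the average over $i$ is the final .mean().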