XinhaoMei / audio-text_retrieval

Implementation of our paper 'On Metric Learning For Audio-Text Cross-Modal Retrieval'

Understanding the NT-Xent loss function #5

Open Vedanshi-Shah opened 1 year ago

Vedanshi-Shah commented 1 year ago

Could you explain the significance of mask in the NT-Xent loss function?

# mask[i, j] is True when items i and j share a label
mask = labels.expand(n, n).eq(labels.expand(n, n).t()).to(a2t.device)
# remove the diagonal so only the off-diagonal positives stay marked
mask_diag = mask.diag()
mask_diag = torch.diag_embed(mask_diag)
mask = mask ^ mask_diag

# zero out off-diagonal positives, then average the diagonal entries
a2t_loss = - self.loss(a2t).masked_fill(mask, 0).diag().mean()
t2a_loss = - self.loss(t2a).masked_fill(mask, 0).diag().mean()

From what we have inferred, the mask excludes the diagonal positive pairs (i.e. [i, i]) but marks the off-diagonal positive pairs [i, j] (where i != j).

In the final a2t_loss calculation, we take the mean of the diagonal values rather than a mean over the negative pairs. Since the NT-Xent loss is supposed to account for the similarity of negative pairs, how is that being calculated?
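To make the mask construction concrete, here is a minimal, self-contained sketch of the same lines on a hypothetical batch of 4 items, where items 0 and 2 are assumed to share a label (the labels tensor is invented for illustration):

```python
import torch

# Hypothetical mini-batch: items 0 and 2 share a label, so they are
# positives of each other in addition to the diagonal [i, i] pairs.
labels = torch.tensor([0, 1, 0, 2])
n = labels.size(0)

# Same construction as in the question: mask[i, j] is True when
# items i and j share a label (labels.expand(n, n)[i, j] == labels[j],
# its transpose gives labels[i]).
mask = labels.expand(n, n).eq(labels.expand(n, n).t())

# XOR out the diagonal, leaving only the off-diagonal positives marked.
mask_diag = torch.diag_embed(mask.diag())
mask = mask ^ mask_diag

print(mask)
# The diagonal is False everywhere; only [0, 2] and [2, 0] are True.
```

Running this shows why masked_fill(mask, 0) cannot touch the diagonal: the mask is False there by construction, so the .diag() taken afterwards is unaffected.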

XinhaoMei commented 1 year ago

Softmax is applied in self.loss().
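To spell the answer out: if self.loss() is a row-wise log-softmax (an assumption consistent with the reply, not the repository code verbatim), then each diagonal entry of self.loss(a2t) already contains every negative pair in its denominator, so -diag().mean() is exactly a cross-entropy over the similarity rows:

```python
import torch
import torch.nn as nn

# Sketch assuming self.loss is nn.LogSoftmax(dim=1).
torch.manual_seed(0)
n = 4
a2t = torch.randn(n, n)  # similarity matrix: audio i vs. text j

log_softmax = nn.LogSoftmax(dim=1)

# log_softmax(a2t)[i, i] = sim(i, i) - log(sum_j exp(sim(i, j))).
# The log-sum-exp runs over every column j, so all negative pairs
# contribute to the loss even though only the diagonal is kept.
a2t_loss = -log_softmax(a2t).diag().mean()

# Equivalent cross-entropy formulation: the i-th text is the target
# class for the i-th audio clip.
targets = torch.arange(n)
ce_loss = nn.functional.cross_entropy(a2t, targets)

print(a2t_loss.item(), ce_loss.item())  # the two values match
```

So the negatives are not ignored: they enter through the softmax normalisation, and the mask only zeroes off-diagonal entries that happen to be positives (which, with unique labels, leaves the diagonal untouched anyway).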