wzhouad / Contra-OOD

Source code for paper "Contrastive Out-of-Distribution Detection for Pretrained Transformers", EMNLP 2021
MIT License

Mask is always zero in Margin-based Contrastive Loss #5

OrigamiDream closed this issue 2 years ago

OrigamiDream commented 2 years ago

Reference: https://github.com/wzhouad/Contra-OOD/blob/main/model.py#L40

I have a question about the implementation of the Margin-based Contrastive Loss.

```python
mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
```

If the batch size is 64, the labels variable has a shape of [64]. When the above code runs, broadcasting ([64, 1] == [1, 64]).float() produces a [64, 64] matrix, which is exactly a 2D diagonal matrix.

```python
mask = mask - torch.diag(torch.diag(mask))
```

But the problem is in the second line of code. When torch.diag(mask) runs, the result has a shape of [64] and is a one-filled vector: $[1, 1, 1, \ldots]$. Therefore, the result of torch.diag(torch.diag(mask)) is exactly the same as mask, which is again an exact 2D diagonal matrix. Furthermore, if you subtract this result from mask, the mask always ends up as a zero-filled matrix. As a result, the mask variable has no influence on gradient descent.
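For illustration, here is a minimal sketch that reproduces this behavior, under the assumption that every label in the batch is distinct (a hypothetical batch of size 4 is used in place of 64):

```python
import torch

# Minimal sketch of the observation above, assuming every label in the
# batch is distinct.
labels = torch.tensor([0, 1, 2, 3])

# Broadcasting [4, 1] == [1, 4] gives a [4, 4] same-label matrix.
mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
print(mask)  # identity matrix: only the self-pairs (i, i) match

# torch.diag(mask) extracts the diagonal -> a one-filled vector of shape [4].
# torch.diag of that vector rebuilds the same diagonal matrix, so the
# subtraction zeroes mask out entirely in this particular case.
mask = mask - torch.diag(torch.diag(mask))
print(mask)  # all zeros when labels are pairwise distinct
```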

Is this really intentional?

I thought the mask variable was used to distinguish $P(i)$ and $N(i)$ in the equation. Is that right? Or am I missing something?

OrigamiDream commented 2 years ago

That was my mistake; problem solved. The mask is a diagonal matrix only when every label in the batch happens to be distinct. With repeated labels there are off-diagonal ones, and subtracting the diagonal only removes the self-pairs $(i, i)$, leaving exactly the positive pairs $P(i)$.
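A minimal sketch of this resolution, again with a hypothetical batch of size 4 but now with repeated labels:

```python
import torch

# Sketch of why the subtraction does not zero out mask in general:
# with repeated labels (the usual case during training), off-diagonal
# entries mark the positive pairs P(i); only self-pairs are removed.
labels = torch.tensor([0, 1, 0, 1])

mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.]])

mask = mask - torch.diag(torch.diag(mask))
# tensor([[0., 0., 1., 0.],
#         [0., 0., 0., 1.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])
# mask[i, j] == 1 exactly when j is in P(i): same label, j != i.
print(mask)
```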