I have a question regarding the implementation of instance_contrastive_loss:
import torch
import torch.nn.functional as F

def instance_contrastive_loss(z1, z2):
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        # contrastive loss requires pair.
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0)  # 2B x T x C
    z = z.transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    i = torch.arange(B, device=z1.device)
    loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
    return loss
In your implementation, you slice the result of tril with [:, :, :-1] and the result of triu with [:, :, 1:] when building the logits. Why is that? Is there something I have missed?
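For reference, here is a minimal pure-Python sketch of what the two slices appear to do on a single 2B x 2B similarity matrix. The helpers `tril`, `triu`, and `drop_diagonal` are hypothetical stand-ins written only for illustration (they are not the repository's code); they mimic `torch.tril` and `torch.triu` on nested lists so the effect can be checked by hand. The matrix values are made up so that each entry encodes its own row and column.

```python
def tril(m, diagonal=0):
    # Keep entries with col - row <= diagonal, zero the rest (mimics torch.tril).
    return [[v if c - r <= diagonal else 0 for c, v in enumerate(row)]
            for r, row in enumerate(m)]

def triu(m, diagonal=0):
    # Keep entries with col - row >= diagonal, zero the rest (mimics torch.triu).
    return [[v if c - r >= diagonal else 0 for c, v in enumerate(row)]
            for r, row in enumerate(m)]

def drop_diagonal(sim):
    # After tril(diagonal=-1) the last column is all zeros, so [:-1] discards nothing useful.
    lower = [row[:-1] for row in tril(sim, diagonal=-1)]
    # After triu(diagonal=1) the first column is all zeros, so [1:] shifts the
    # strictly upper entries one position to the left.
    upper = [row[1:] for row in triu(sim, diagonal=1)]
    # Summing the two halves gives each row of sim with its diagonal entry removed.
    return [[a + b for a, b in zip(lo, up)] for lo, up in zip(lower, upper)]

B = 2
n = 2 * B
# Illustrative values: entry (r, c) is written as the two-digit number "rc" (1-based).
sim = [[10 * (r + 1) + (c + 1) for c in range(n)] for r in range(n)]
logits = drop_diagonal(sim)

for row in logits:
    print(row)
```

Under these assumptions, each row keeps its 2B - 1 off-diagonal similarities and loses only the self-similarity on the diagonal. Because one column disappears, the positive for row i moves from column B + i to column B + i - 1, while the positive for row B + i stays at column i, which seems to match the indices used in the loss line.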
Hello, thank you for sharing your work!
Thank you in advance!
Best,