zhihanyue / ts2vec

A universal time series representation learning framework
MIT License

Question regarding the implementation of instance contrastive loss #42

Open jaeho3690 opened 10 months ago


Hello, thank you for sharing your work!

I have a question regarding the implementation of `instance_contrastive_loss`:

```python
import torch
import torch.nn.functional as F

def instance_contrastive_loss(z1, z2):
    # z1, z2: B x T x C
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        # contrastive loss requires pair.
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0)  # 2B x T x C
    z = z.transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]    # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)

    # positives: z1[i] is paired with z2[i] at every timestamp
    i = torch.arange(B, device=z1.device)
    loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
    return loss
```
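One way to read the function above is as a softmax cross-entropy at each timestamp, where each of the 2B instances has to pick out its counterpart from the other view among the remaining 2B-1 instances. The sketch below checks that reading under stated assumptions: it repeats the quoted function so the block runs on its own, and compares it with a hypothetical `masked_instance_contrastive_loss` that excludes self-similarity by filling the diagonal with `-inf` before calling `F.cross_entropy`. The masked variant is only an illustration, not code from the ts2vec repository.

```python
import torch
import torch.nn.functional as F

def instance_contrastive_loss(z1, z2):
    # Same logic as the snippet quoted above, repeated so this block is self-contained.
    B = z1.size(0)
    if B == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0).transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    i = torch.arange(B, device=z1.device)
    return (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2

def masked_instance_contrastive_loss(z1, z2):
    """Hypothetical reformulation: mask self-similarity with -inf, then use cross_entropy."""
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0).transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    T = sim.size(0)
    eye = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(eye, float('-inf'))  # each instance gets zero probability for itself
    # Target of row r is its counterpart: r + B in the first half, r - B in the second half.
    targets = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)]).to(z.device)
    # Flatten timestamps into the batch dimension: row t * 2B + r uses targets[r].
    return F.cross_entropy(sim.reshape(T * 2 * B, 2 * B), targets.repeat(T))

torch.manual_seed(0)
z1, z2 = torch.randn(4, 5, 3), torch.randn(4, 5, 3)  # B=4, T=5, C=3
print(instance_contrastive_loss(z1, z2), masked_instance_contrastive_loss(z1, z2))
```

With the default `reduction='mean'`, `F.cross_entropy` averages over all T x 2B rows; the original averages each half (T x B rows) separately and divides by 2, which gives the same value because both halves have the same number of rows.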

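In your implementation, the logits are assembled from `torch.tril(sim, diagonal=-1)[:, :, :-1]` (dropping the last column) and `torch.triu(sim, diagonal=1)[:, :, 1:]` (dropping the first column). Why is this slicing needed? Is there something I have missed?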
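For reference, here is a small sketch of what that slicing does on a single 2B x 2B similarity matrix (one timestamp). It uses NumPy for brevity; `np.tril(..., k=-1)` and `np.triu(..., k=1)` follow the same diagonal convention as `torch.tril(..., diagonal=-1)` and `torch.triu(..., diagonal=1)`. The helper names are hypothetical. In this sketch, the strictly lower triangle always has an all-zero last column and the strictly upper triangle an all-zero first column, so dropping those columns and adding the two parts leaves each row with every similarity except its own diagonal entry.

```python
import numpy as np

def drop_diagonal_via_tri(sim):
    """Hypothetical helper mirroring the tril/triu slicing on one 2B x 2B matrix."""
    lower = np.tril(sim, k=-1)[:, :-1]  # strictly lower part; its last column is always zero
    upper = np.triu(sim, k=1)[:, 1:]    # strictly upper part; its first column is always zero
    return lower + upper                # 2B x (2B-1): row i keeps sim[i, j] for every j != i

def drop_diagonal_via_mask(sim):
    """Hypothetical reference: delete each row's diagonal entry with an explicit mask."""
    n = sim.shape[0]
    return sim[~np.eye(n, dtype=bool)].reshape(n, n - 1)

B = 2
sim = np.arange(1, (2 * B) ** 2 + 1, dtype=float).reshape(2 * B, 2 * B)
logits = drop_diagonal_via_tri(sim)  # rows: [2, 3, 4], [5, 7, 8], [9, 10, 12], [13, 14, 15]
assert np.array_equal(logits, drop_diagonal_via_mask(sim))

# Row i < B: original column B + i moves left by one (to B + i - 1) once the diagonal is gone.
# Row B + i: original column i is left of the diagonal, so it keeps its index.
i = np.arange(B)
assert np.array_equal(logits[i, B + i - 1], sim[i, B + i])
assert np.array_equal(logits[B + i, i], sim[B + i, i])
```

The index shift in the last two checks is why the loss reads the positive pair from column `B + i - 1` for the first half and from column `i` for the second half.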

Thank you in advance!

Best,