zhihanyue / ts2vec

A universal time series representation learning framework

Question on Instance Contrastive Loss #27

Closed KiwiAthlete closed 1 year ago

KiwiAthlete commented 1 year ago

Problem Description

@yuezhihan Thanks for making the code available, really nice approach.

Going through the code, I have some difficulties understanding the instance contrastive loss, starting from your paper's explanation.
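
For reference, the instance-wise contrastive loss from the paper, as I understand it (my own transcription, so the notation may not match the paper exactly):

$$\ell_{\mathrm{inst}}^{(i,t)} = -\log \frac{\exp(z_{i,t} \cdot z'_{i,t})}{\sum_{j=1}^{B}\left(\exp(z_{i,t} \cdot z'_{j,t}) + \mathbb{1}_{[i \neq j]}\exp(z_{i,t} \cdot z_{j,t})\right)}$$

where $z_{i,t}$ and $z'_{i,t}$ denote the representations of sample $i$ at timestamp $t$ from the two augmented views, and $B$ is the batch size.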

I was trying to understand your implementation, in particular the final part, where you take specific entries across the batch dimension of the logits tensor.

Details

For ease of explanation, I created a small example:


import torch
import torch.nn.functional as F

# create two batches of time series representations (B=3, T=4, C=1)
z1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32).reshape(3, 4, 1)
z2 = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32).reshape(3, 4, 1)

# get batch size and length of time series
B, T = z1.size(0), z1.size(1)

# edge case: with a single batch element the original instance_contrastive_loss
# returns z1.new_tensor(0.) immediately; here B == 3, so we just assert
assert B > 1

# concatenate z1 and z2 along the batch dimension to form a tensor of shape (2B, T, C)
z = torch.cat([z1, z2], dim=0)

# transpose z to shape (T, 2B, C)
z = z.transpose(0, 1)

# calculate the dot product between z and its transpose to get a similarity matrix of shape (T, 2B, 2B)
sim = torch.matmul(z, z.transpose(1, 2))

# extract the lower triangular part of sim, excluding the main diagonal that reflects self-similarity (T, 2B, 2B-1)
logits = torch.tril(sim, diagonal=-1)[:, :, :-1]

# add the upper triangular part of sim, excluding the main diagonal that reflects self-similarity (T, 2B, 2B-1)
logits += torch.triu(sim, diagonal=1)[:, :, 1:]

# apply softmax along the last dimension and take the negative log-likelihood
logits = -F.log_softmax(logits, dim=-1)

# use arange to create a tensor of batch indices 0..B-1
i = torch.arange(B, device=z1.device)

# average the logits at the specified entries across timestamps and batch elements
loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
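
To keep track of the dimensions, the shapes at each step in this example are:

print(z.shape)       # torch.Size([4, 6, 1])  -> (T, 2B, C)
print(sim.shape)     # torch.Size([4, 6, 6])  -> (T, 2B, 2B)
print(logits.shape)  # torch.Size([4, 6, 5])  -> (T, 2B, 2B-1)
print(loss.shape)    # torch.Size([])         -> scalar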

For illustration purposes, let's focus on the first timestamp of the similarity matrix, i.e. sim[0] of shape (2B, 2B):

tensor([[[  1.,   5.,   9.,  13.,  17.,  21.],
         [  5.,  25.,  45.,  65.,  85., 105.],
         [  9.,  45.,  81., 117., 153., 189.],
         [ 13.,  65., 117., 169., 221., 273.],
         [ 17.,  85., 153., 221., 289., 357.],
         [ 21., 105., 189., 273., 357., 441.]],
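
This slice is simply the outer product of the t=0 values of the 2B concatenated samples, which can be checked directly:

# values at t = 0 across the concatenated batch: [1, 5, 9, 13, 17, 21]
t0 = z[0]                              # shape (2B, 1)
assert torch.equal(sim[0], t0 @ t0.T)  # sim[0] is the outer product of t0 with itself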

Calling torch.tril(sim, diagonal=-1)[:, :, :-1] keeps the strictly lower triangular part and drops the last (all-zero) column:

tensor([[  0.,   0.,   0.,   0.,   0.],
        [  5.,   0.,   0.,   0.,   0.],
        [  9.,  45.,   0.,   0.,   0.],
        [ 13.,  65., 117.,   0.,   0.],
        [ 17.,  85., 153., 221.,   0.],
        [ 21., 105., 189., 273., 357.]])

Calling torch.triu(sim, diagonal=1)[:, :, 1:] keeps the strictly upper triangular part and drops the first (all-zero) column:

tensor([[  5.,   9.,  13.,  17.,  21.],
        [  0.,  45.,  65.,  85., 105.],
        [  0.,   0., 117., 153., 189.],
        [  0.,   0.,   0., 221., 273.],
        [  0.,   0.,   0.,   0., 357.],
        [  0.,   0.,   0.,   0.,   0.]])
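
The dropped columns are exactly the all-zero ones, so no similarity value is lost; a quick check:

# the removed columns contain no similarity values
assert torch.all(torch.tril(sim, diagonal=-1)[:, :, -1] == 0)  # last column of the strict lower part
assert torch.all(torch.triu(sim, diagonal=1)[:, :, 0] == 0)    # first column of the strict upper part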

Adding the two parts removes the self-similarity diagonal and closes the gap it leaves, giving a (2B, 2B-1) matrix of pairwise dot products per timestamp: logits = torch.tril(sim, diagonal=-1)[:, :, :-1] + torch.triu(sim, diagonal=1)[:, :, 1:]

print(logits)

tensor([[  5.,   9.,  13.,  17.,  21.],
        [  5.,  45.,  65.,  85., 105.],
        [  9.,  45., 117., 153., 189.],
        [ 13.,  65., 117., 221., 273.],
        [ 17.,  85., 153., 221., 357.],
        [ 21., 105., 189., 273., 357.]])
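
In other words, each row r of this (pre-softmax) matrix is row r of sim[0] with the self-similarity entry sim[0, r, r] removed, which can be verified like this:

# every row equals the corresponding row of sim with its diagonal entry removed
for r in range(2 * B):
    row_without_diag = torch.cat([sim[0, r, :r], sim[0, r, r + 1:]])
    assert torch.equal(logits[0, r], row_without_diag)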

Applying the softmax along each row and taking the negative log gives the negative log-likelihood:

logits = -F.log_softmax(logits, dim=-1)
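
Written out for row $r$ and column $c$ (my own reasoning, to connect it back to the loss above):

$$-\log\operatorname{softmax}(x_r)_c = -x_{r,c} + \log\sum_{c'}\exp(x_{r,c'}),$$

so if we later pick the column of the positive pair, the selected entry plays the role of the numerator in the paper's loss, and the row-wise sum over all 2B-1 remaining similarities (the one positive plus all negatives, with the self-similarity already removed) plays the role of the denominator.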

We then use the batch size B = 3 to create an index for extracting specific entries of the logits tensor:

i = torch.arange(B, device=z1.device) 
    # i = tensor([0, 1, 2])
loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
    # (B + i - 1) = tensor([2, 3, 4])
    # (B + i) = tensor([3, 4, 5])
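
If it helps, here is how I currently read those indices (a sketch of my own; shifted_col is a hypothetical helper, not part of the repo):

# removing the diagonal shifts every column to the right of the diagonal one step left,
# so a column c of the full (2B, 2B) sim matrix ends up at this column of logits:
def shifted_col(row: int, col: int) -> int:
    assert row != col, "the self-similarity entry no longer exists in logits"
    return col - 1 if col > row else col

# row i (a z1 sample) has its positive partner, z2-sample i, in sim column B + i;
# since B + i > i, that column becomes B + i - 1 in logits
assert shifted_col(row=0, col=B + 0) == B - 1

# row B + i (a z2 sample) has its positive partner, z1-sample i, in sim column i;
# since i < B + i, that column stays i in logits
assert shifted_col(row=B + 0, col=0) == 0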

Questions

Based on the above example, my questions are as follows:

tribeband commented 1 month ago

What is the explanation? I am confused about this as well.