@yuezhihan Thanks for making the code available, really nice approach.
Going through the code, I have some difficulties understanding the Instance Contrastive Loss. Starting from the explanation in your paper, I was trying to understand your implementation, in particular the final part, where you take specific entries across batches of the logits tensor.
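For reference, this is the instance-wise contrastive loss as I recall it from the paper (quoting from memory, so the notation may differ slightly):

$$
\ell_{\mathrm{inst}}^{(i,t)} = -\log \frac{\exp\left(z_{i,t} \cdot z'_{i,t}\right)}{\sum_{j=1}^{B} \left( \exp\left(z_{i,t} \cdot z'_{j,t}\right) + \mathbb{1}_{[i \neq j]} \exp\left(z_{i,t} \cdot z_{j,t}\right) \right)}
$$

where $z$ and $z'$ are the representations of the two views, $i$ indexes the instance within the batch, and $t$ the timestamp.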
Details
For ease of explanation, I created a small example:

```python
import torch
import torch.nn.functional as F

# Create time series tensors (B:3, T:4, C:1)
z1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32).reshape(3, 4, 1)
z2 = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32).reshape(3, 4, 1)
# get batch size and length of time series
B, T = z1.size(0), z1.size(1)
# (in the original loss function, a batch with a single element returns a zero loss right away)
# if B == 1:
#     return z1.new_tensor(0.)
# concatenate z1 and z2 along the batch dimension to form a tensor of shape (2B, T, C)
z = torch.cat([z1, z2], dim=0)
# transpose z to shape (T, 2B, C)
z = z.transpose(0, 1)
# calculate the dot product between z and its transpose to get a similarity matrix of shape (T, 2B, 2B)
sim = torch.matmul(z, z.transpose(1, 2))
# take the lower triangular part of sim, excluding the main diagonal that reflects self-similarity,
# and drop the last column -> shape (T, 2B, 2B-1)
logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
# add the upper triangular part of sim, excluding the main diagonal, with the first column dropped -> (T, 2B, 2B-1)
logits += torch.triu(sim, diagonal=1)[:, :, 1:]
# apply log_softmax along the last dimension and negate to obtain negative log-likelihoods
logits = -F.log_softmax(logits, dim=-1)
# use arange to create a tensor of indices
i = torch.arange(B, device=z1.device)
# average the selected entries over all timestamps and batch elements
loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
```
For illustration purposes, let's focus on the first slice of the similarity matrix, `sim[0]` (i.e., the first timestamp).
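With the toy tensors above, this slice is just the outer product of the six instance values at t = 0, namely `[1, 5, 9, 13, 17, 21]` (the three z1 instances followed by the three z2 instances):

```python
print(sim[0])
# tensor([[  1.,   5.,   9.,  13.,  17.,  21.],
#         [  5.,  25.,  45.,  65.,  85., 105.],
#         [  9.,  45.,  81., 117., 153., 189.],
#         [ 13.,  65., 117., 169., 221., 273.],
#         [ 17.,  85., 153., 221., 289., 357.],
#         [ 21., 105., 189., 273., 357., 441.]])
```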
Removing the diagonal (and the remaining zero elements the triangular masks leave behind) gives a matrix of pairwise dot products: `logits = torch.tril(sim, diagonal=-1)[:, :, :-1] + torch.triu(sim, diagonal=1)[:, :, 1:]`
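Computed on the toy example (using `pairwise` as a temporary name of my own for this pre-softmax matrix), the slice at t = 0 keeps, per row, the 5 dot products with the other instances:

```python
pairwise = torch.tril(sim, diagonal=-1)[:, :, :-1] + torch.triu(sim, diagonal=1)[:, :, 1:]
print(pairwise[0])
# tensor([[  5.,   9.,  13.,  17.,  21.],
#         [  5.,  45.,  65.,  85., 105.],
#         [  9.,  45., 117., 153., 189.],
#         [ 13.,  65., 117., 221., 273.],
#         [ 17.,  85., 153., 221., 357.],
#         [ 21., 105., 189., 273., 357.]])
```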
Applying the softmax and the negative log gives the negative log-likelihoods: `logits = -F.log_softmax(logits, dim=-1)`
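In other words, each entry $x_k$ of a row is replaced by

$$
-\log \frac{\exp(x_k)}{\sum_{j=1}^{2B-1} \exp(x_j)},
$$

i.e. within each row (one representation at one timestamp) we get the negative log-likelihood of each of the remaining $2B - 1$ representations under the softmax.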
We then use the batch size B=3 to create an index for extracting certain entries of the logits tensor:

```python
i = torch.arange(B, device=z1.device)
# i = tensor([0, 1, 2])
loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
# (B + i - 1) = tensor([2, 3, 4])
# (B + i)     = tensor([3, 4, 5])
```
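If I traced the indexing correctly, removing the diagonal shifts every column to its right one position to the left, so column `B + i - 1` in row `i` of the reduced `(2B, 2B-1)` matrix is the column that sat at position `B + i` in the full `sim`, while column `i` in row `B + i` lies left of the diagonal and is unaffected. A quick check on the toy example (pre-softmax values, variable name `pairwise` is mine):

```python
pairwise = torch.tril(sim, diagonal=-1)[:, :, :-1] + torch.triu(sim, diagonal=1)[:, :, 1:]
print(pairwise[0, i, B + i - 1])  # tensor([ 13.,  85., 189.])
print(pairwise[0, B + i, i])      # tensor([ 13.,  85., 189.])
print(sim[0, i, B + i])           # tensor([ 13.,  85., 189.])  -> the z1_i · z2_i entries
```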
Questions
Based on the above example, my questions are as follows:
1. How does `loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2` relate to the equation from the paper above?
2. Why do we extract the `logits[:, i, B + i - 1]` entries from the upper-diagonal part? What do these entries mean? Are these the positive examples? Why do we extract exactly these entries?
3. Why do we extract the `logits[:, B + i, i]` entries from the lower-diagonal part? What do these entries mean? Are these the negative examples? Why do we extract exactly these entries?
4. Why do we create the index `i` using the batch size `B`, given that we already take the mean across all batch elements?