Closed cyl250 closed 4 years ago
If you look at Eq. (3) and (4) in https://arxiv.org/pdf/1807.03748.pdf:

`total = torch.mm(encode_samples[i], torch.transpose(pred[i],0,1))`: For each sample in the minibatch, this computes the dot product between z and c for the positive and negative samples. Note that since the negative samples are drawn from the same minibatch, `total` is a B-by-B matrix, where B is the batch size. (The "dot product between z and c" is Equation (3) without the `log`.)

InfoNCE computes the expectation over the log-softmax of the density ratios of the positive samples. That is why the following line first applies `lsoftmax` to the matrix `total` and then takes the diagonal, which corresponds to the positive samples (Equation (4)):

`nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor`
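For reference, here is a rough sketch of how the two lines map onto the paper's notation. Eq. (3) and (4) are written as in the paper; the matrix name $S$ (standing in for `total`), the indices $a, b$, and the assumption that `lsoftmax` normalizes over each row are mine and are not confirmed by the code quoted here.

```latex
% Eq. (3): the score for a (sample, context) pair
f_k(x_{t+k}, c_t) = \exp\left( z_{t+k}^{\top} W_k c_t \right)

% Eq. (4): the InfoNCE loss
\mathcal{L}_N = -\,\mathbb{E}_X \left[ \log
    \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)} \right]

% Assumed notation: S_{ab} is entry (a, b) of `total`, with row-wise log-softmax
\left[ \operatorname{logsoftmax}(S) \right]_{aa}
    = S_{aa} - \log \sum_{b=1}^{B} \exp(S_{ab})
    = \log \frac{\exp(S_{aa})}{\sum_{b=1}^{B} \exp(S_{ab})}
```

Under these assumptions, each diagonal entry is one log-ratio of the kind inside the expectation in Eq. (4), so `torch.sum(torch.diag(...))` adds up the log-ratios of the B positive pairs. Eq. (4) also carries a minus sign and an expectation, so the accumulated `nce` would still need to be negated and averaged before it is the loss.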
Thank you very much. Very NIU
Does that mean no negative samples are used in the end? By taking `diag`, we remove all the negative sample pairs and leave only the positive terms. Am I right?
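The negative samples are still used: when `lsoftmax` is computed, they enter through the denominator. Each diagonal entry therefore already depends on the negative scores it was normalized against.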
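A small sketch may make this concrete. It uses plain Python in place of the `torch` calls, and the 2x2 score matrices are made-up values; `log_softmax_rows` and `diag_sum` are hypothetical helpers, and row-wise normalization is an assumption about how `lsoftmax` is configured. The positive (diagonal) scores are identical in both matrices; only the negative (off-diagonal) scores differ.

```python
import math


def log_softmax_rows(matrix):
    """Row-wise log-softmax of a list-of-lists matrix (assumed axis)."""
    result = []
    for row in matrix:
        row_max = max(row)  # subtract the max for numerical stability
        log_denominator = row_max + math.log(
            sum(math.exp(value - row_max) for value in row)
        )
        result.append([value - log_denominator for value in row])
    return result


def diag_sum(matrix):
    """Sum of the diagonal entries, i.e. the positive-pair terms."""
    return sum(matrix[i][i] for i in range(len(matrix)))


# Hypothetical scores: diagonal = positive pairs, off-diagonal = negatives.
scores_easy = [[2.0, -1.0], [-1.0, 2.0]]  # negatives score low
scores_hard = [[2.0, 1.5], [1.5, 2.0]]    # negatives score almost as high

nce_easy = diag_sum(log_softmax_rows(scores_easy))
nce_hard = diag_sum(log_softmax_rows(scores_hard))

# Same positive scores, different negatives, different diagonal sums.
print(nce_easy, nce_hard)
```

The diagonal sum is lower for `scores_hard` even though the positive scores did not change, which shows that taking the diagonal after the log-softmax does not discard the negatives.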
Hi @jefflai108, in your implementation the negative samples seem to lack diversity: only samples within the `timestep` window with mismatched time steps are used.
I had some trouble understanding the implementation of the InfoNCE loss function. I don't understand how `torch.diag()` could represent the InfoNCE loss.