jefflai108 / Contrastive-Predictive-Coding-PyTorch

Contrastive Predictive Coding for Automatic Speaker Verification
MIT License

Some Trouble in Understanding #5

Closed cyl250 closed 4 years ago

cyl250 commented 4 years ago

I had some trouble understanding the implementation of the InfoNCE loss function. I don't understand how torch.diag() can represent the InfoNCE loss.

jefflai108 commented 4 years ago

Take a look at Eq. (3) and (4) in https://arxiv.org/pdf/1807.03748.pdf:

total = torch.mm(encode_samples[i], torch.transpose(pred[i],0,1)): For each sample in the minibatch, this computes the dot product between z and c for the positive and negative samples. Since the negative samples are drawn from the same minibatch, total is a B-by-B matrix, where B is the batch size. (The "dot product between z and c" is Eq. (3) without the exp, i.e. the log of the density ratio.)

InfoNCE computes the expectation of the log-softmax of the density ratios of the positive samples. That's why the following line first applies lsoftmax to the matrix total and then takes the diagonal, which corresponds to the positive samples (Eq. (4)): nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor
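The two steps above can be sketched on random data. This is only a minimal illustration, not the repository's code: the sizes B and D, the random tensors standing in for encode_samples[i] and pred[i], and the explicit dim=1 on LogSoftmax are all assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, D = 4, 8  # hypothetical batch size and embedding dimension
z = torch.randn(B, D)     # stands in for encode_samples[i]
pred = torch.randn(B, D)  # stands in for pred[i]

# total[a, b] is the dot product of z[a] with pred[b], so total is B by B.
total = torch.mm(z, torch.transpose(pred, 0, 1))

# Log-softmax along each row (dim=1 is an assumption made explicit here).
lsoftmax = nn.LogSoftmax(dim=1)
log_probs = lsoftmax(total)

# The diagonal holds the matched (positive) pairs.
nce = torch.sum(torch.diag(log_probs))

# The same quantity written out by hand: for each row, the positive logit
# minus the log-sum-exp over the whole row (positive and negatives).
manual = torch.sum(torch.diag(total) - torch.logsumexp(total, dim=1))

print(nce.item(), manual.item())
```

The hand-written version makes it visible that each diagonal entry is a positive logit normalized by a sum over the entire row.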

cyl250 commented 4 years ago

Thank you very much. Very NIU

KinWaiCheuk commented 3 years ago

Does that mean no negative samples are used in the end? By taking diag, we discard all the negative pairs and keep only the positive terms. Am I right?

bfs18 commented 3 years ago

The negative samples are still used: when lsoftmax is computed, they enter the denominator (the normalizing sum) of the softmax, so every diagonal entry depends on them.
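A small numeric check of this point, written in plain Python so the arithmetic is visible. The 2-by-2 score matrices and the helper names log_softmax_row and diag_log_softmax_sum are made up for illustration and do not come from the repository.

```python
import math

def log_softmax_row(row):
    # Numerically stable log-softmax of one row of scores.
    m = max(row)
    log_denom = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_denom for v in row]

def diag_log_softmax_sum(total):
    # Sum of the diagonal of the row-wise log-softmax, i.e. the positive terms.
    return sum(log_softmax_row(row)[i] for i, row in enumerate(total))

# Hypothetical score matrices: the diagonal (positive pairs) is identical,
# only one off-diagonal (negative) score differs.
total = [[2.0, 0.5],
         [0.1, 1.5]]
harder = [[2.0, 1.9],
          [0.1, 1.5]]

base = diag_log_softmax_sum(total)
changed = diag_log_softmax_sum(harder)
print(base, changed)
```

Although only a negative score was changed, the summed diagonal changes as well, because the larger negative score increases the softmax denominator for that row.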

bfs18 commented 3 years ago

Hi @jefflai108, in your implementation the negative samples seem to lack diversity: only samples within the timestep window that have a mismatched time step are used.