[Closed] Mathilda88 closed this issue 3 years ago
The cross entropy loss, implemented as torch.nn.CrossEntropyLoss, takes two inputs.

The first argument is the tensor of logits, named predictions in our code. We have a total of (B) x (S) predictions: S patches for each of B images. Each prediction is a row of S + 1 logits, and the 0th entry of that row is the positive (target) logit.

The second argument gives the index of the target among the S + 1 logits, i.e., a single class index per prediction. Its shape is therefore (BS), not (BS) x (S + 1).

So it is incorrect to say that we are computing cross entropy against an all-zero vector. The all-zero tensor holds class indices: we are computing cross entropy against a one-hot distribution that is nonzero at index 0.
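To make the index-vs-one-hot point concrete, here is a small standalone check (the shapes are made-up toy values, not the real B, C, S): passing all-zero class indices to CrossEntropyLoss gives exactly the same loss as computing cross entropy against one-hot rows [1, 0, ..., 0].

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration: B*S = 6 "predictions", each with S + 1 = 4 logits.
predictions = torch.randn(6, 4)
# Class-index targets: the positive logit sits at index 0 for every prediction.
targets = torch.zeros(6, dtype=torch.long)

# CrossEntropyLoss consumes the class-index form directly ...
loss_index = torch.nn.CrossEntropyLoss()(predictions, targets)

# ... which is equivalent to cross entropy against one-hot vectors [1, 0, 0, 0].
one_hot = F.one_hot(targets, num_classes=4).float()
loss_onehot = -(one_hot * F.log_softmax(predictions, dim=1)).sum(dim=1).mean()

assert torch.allclose(loss_index, loss_onehot)
```

So the "all-zero vector" in the code is not a zero target distribution; it is a batch of class indices, all pointing at column 0.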
Hi everyone,

Can anyone here help me understand the foundation of this code? I understand the code up to the logits. I know that logits has S + 1 columns and S rows; in each row, the first S entries come from multiplying the sample with the negative patches, and the remaining one is the positive. After that I expected a sum to be applied over the logits, but I cannot follow the remaining part. Why do we need to calculate the cross entropy against an all-zero vector?

Any help would be appreciated. Thanks!
```python
import torch

cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape
```
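For reference, here is a self-contained sketch of how the rest of the function can be completed so that it matches the explanation above: positive logit at column 0, and an all-zero target tensor of class indices. This follows the published CUT pseudocode; treat it as a sketch rather than the exact repository code. The -10.0 on the diagonal is just a large negative constant that masks out each patch's similarity with itself, since those entries are not valid negatives.

```python
import torch

cross_entropy_loss = torch.nn.CrossEntropyLoss()

def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape
    # positive logits v * v+: per-location dot product, shape BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]
    # negative logits v * v-: every query location vs. every key location, BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)
    # the diagonal pairs a patch with itself, so it is not a valid negative;
    # fill it with a large negative value to remove it from the softmax
    identity_matrix = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -10.0)
    # logits: BxSx(S+1), with the positive logit at column 0
    logits = torch.cat((l_pos, l_neg), dim=2) / tau
    # one (S+1)-way classification problem per patch: B*S rows of logits,
    # and the target class index is 0 for every row
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)

# toy usage with random features, just to check the shapes and that the
# result is a finite scalar (sizes here are arbitrary assumptions)
f_q = torch.randn(2, 8, 5)
f_k = torch.randn(2, 8, 5)
loss = PatchNCELoss(f_q, f_k)
```

So there is no explicit sum over the logits: CrossEntropyLoss applies log-softmax over the S + 1 columns internally, and the zeros tensor simply tells it that column 0 is the correct class for every patch.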