taesungp / contrastive-unpaired-translation

Contrastive unpaired image-to-image translation, faster and lighter training than CycleGAN (ECCV 2020, in PyTorch)
https://taesung.me/ContrastiveUnpairedTranslation/

Understanding the foundation of the pseudocode #97

Closed Mathilda88 closed 3 years ago

Mathilda88 commented 3 years ago

Hi everyone,

Can anyone here help me understand the foundation of this code? I follow the code up to the logits. I know that logits has S+1 columns and S rows: in each row, the first S entries come from the multiplication of each sample with the negative patches, and the last one is the positive. After that, I expected a sum to be applied over the logits, but I cannot understand the remaining part. Why do we need to calculate the cross entropy against an all-zero vector?

Any help would be appreciated. Thanks!

```python
import torch

cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are not negatives. Remove them.
    identity_matrix = torch.eye(S, dtype=torch.bool)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))

    # calculate logits: (B)x(S)x(S+1)
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return PatchNCE loss
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long)
    return cross_entropy_loss(predictions, targets)
```
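
For reference, a minimal sketch of how the function above could be exercised on dummy inputs (the shapes and the unit-normalization step are my assumptions, not taken from the snippet):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 2 images, 256-dim features, 64 sampled patch locations.
B, C, S = 2, 256, 64

# Random stand-ins for H(G_enc(x)) and H(G_enc(G(x))), unit-normalized
# along the channel dimension (an assumption for this sketch).
f_q = F.normalize(torch.randn(B, C, S), dim=1)
f_k = F.normalize(torch.randn(B, C, S), dim=1)

loss = PatchNCELoss(f_q, f_k, tau=0.07)
print(loss)  # a single scalar tensor
```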
taesungp commented 3 years ago

The cross entropy loss, implemented using torch.nn.CrossEntropyLoss, takes two inputs.

The first argument is the list of logits, named predictions in our code. We have a total of (B) x (S) predictions: S patches for B images. Each prediction consists of S + 1 logits, and the 0th entry of these logits contains the target.

The second argument indicates the target index among the S + 1 logits. So there is only a single index per prediction, and the dimension of the second argument is (B x S), not (B x S) x (S + 1).

Therefore, it is incorrect to say that we are computing cross entropy against an all-zero vector. We are computing cross entropy against a one-hot vector that is nonzero at index 0.
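
To make the target convention concrete, here is a small sketch (not from the repository) showing that passing class index 0 to the cross entropy loss is equivalent to using a one-hot target at index 0:

```python
import torch
import torch.nn.functional as F

# 3 predictions, each with S + 1 = 5 logits.
logits = torch.randn(3, 5)

# Class-index form: the positive for every prediction sits at index 0.
targets = torch.zeros(3, dtype=torch.long)
loss_index = F.cross_entropy(logits, targets)

# Same quantity computed by hand: negative log-softmax probability at index 0.
loss_manual = -F.log_softmax(logits, dim=1)[:, 0].mean()

print(torch.allclose(loss_index, loss_manual))  # True
```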