Closed by biggzlar 4 years ago
Hi Lennart, TL;DR: the other elements in the batch act as negative examples.
So N is the batch size, and both `predictions` (the global vectors) and `positive` (the local feature vectors) are N×d matrices. When we matrix-multiply the two as in line 109:

```python
logits = torch.matmul(predictions, positive.t())
```

the result is `logits`, an N×N matrix.
Element (i, j) of this matrix is the dot product of the i-th global vector in the batch with the j-th local feature vector (at spatial location (y, x)) in the batch. When i = j we have a "positive pair", because both elements of the pair are at the same batch index, which means they come from consecutive examples in an episode. When i != j the pair are not from consecutive examples, so they form a "negative pair". So the diagonal of `logits` holds the dot products of positive pairs, and the off-diagonal entries are negative pairs.
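The structure above can be checked with a small toy example (the shapes and variable names here are illustrative, not taken from the repo):

```python
import torch

N, d = 4, 8  # toy batch size and feature dimension
predictions = torch.randn(N, d)  # global vectors, one per batch element
positive = torch.randn(N, d)     # local feature vectors at one (y, x) location

# N x N score matrix: entry (i, j) = dot(predictions[i], positive[j])
logits = torch.matmul(predictions, positive.t())

# the diagonal holds exactly the positive-pair dot products
diag = torch.diagonal(logits)
positive_scores = (predictions * positive).sum(dim=1)
```

Here `diag` and `positive_scores` agree (up to floating-point rounding), confirming that entry (i, i) really is the score of the i-th positive pair.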
The contrastive loss can be seen as classifying the positive pair against the other negative pairs, which is why we use cross-entropy.
So if we treat each of the N rows of `logits` as the logits for one of N separate classification problems, then the correct class for row 0 is element 0, for row 1 it's element 1, for row 2 it's element 2, and for row N-1 it's element N-1 (because those are the entries on the diagonal of the matrix). Hence the target class for row 0 is 0, for row 1 is 1, ..., for row N-1 is N-1, so `target` is [0, 1, ..., N-1], i.e. `torch.arange(N)`.
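Putting it together, a minimal sketch of this N-way classification loss (toy shapes, illustrative names):

```python
import torch
import torch.nn.functional as F

N, d = 4, 8  # toy batch size and feature dimension
predictions = torch.randn(N, d)  # global vectors
positive = torch.randn(N, d)     # local feature vectors at one (y, x) location

logits = torch.matmul(predictions, positive.t())  # (N, N) score matrix

# row i is an N-way classification problem whose correct class is column i
target = torch.arange(N)
loss = F.cross_entropy(logits, target)
```

Each row of `logits` is softmax-normalized, and the loss rewards the diagonal (positive-pair) score for dominating the off-diagonal (negative-pair) scores in its row.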
Let me know if that helps!
Hey Evan, thanks a lot for your time and the detailed explanation. That makes it a lot clearer! I suppose I was thrown off by the distinction between different class labels, instead of a binary cross-entropy over positive and negative samples, but I see now how this makes sense.
I hope you don't mind me asking a follow-up question. Is it true that this loss only captures the global-local objective? Maybe I am failing to see the local-local element?
Hi Lennart, no problem! Yeah, you're correct! The file `global_local_infonce.py` is just the ablation where we remove the local-local loss. The full loss used in the paper, with both terms, is in `stdim.py`.
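For intuition, a local-local term has the same N-way InfoNCE shape, but compares local feature vectors at the same spatial location across consecutive frames. The sketch below is only the general shape under assumed (N, d, H, W) feature maps; names and shapes are hypothetical, and the actual implementation lives in `stdim.py`:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: local feature maps for consecutive frames
N, d, H, W = 4, 8, 3, 3
f_t = torch.randn(N, d, H, W)   # local features of frame t
f_t1 = torch.randn(N, d, H, W)  # local features of frame t+1

loss = torch.tensor(0.0)
for y in range(H):
    for x in range(W):
        a = f_t[:, :, y, x]   # (N, d) local vectors at this location, frame t
        b = f_t1[:, :, y, x]  # (N, d) local vectors at this location, frame t+1
        logits = torch.matmul(a, b.t())  # (N, N): diagonal = positive pairs
        loss = loss + F.cross_entropy(logits, torch.arange(N))
loss = loss / (H * W)  # average over spatial locations
```

As in the global-local case, the negatives at each location come from the other batch elements.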
Ha, oh well. Should have read the filenames more carefully. Thanks a lot for the speedy support.
https://github.com/mila-iqia/atari-representation-learning/blob/59f3e3b94c4a0a61a6cb63acf247dd5a8662fadd/atariari/methods/global_local_infonce.py#L110
This line is a little confusing to me. How does the range over N samples correspond to our cross-entropy targets? All help is highly appreciated!