mila-iqia / atari-representation-learning

Code for "Unsupervised State Representation Learning in Atari"
https://arxiv.org/abs/1906.08226
MIT License

Intuition on cross entropy. #58

Closed biggzlar closed 4 years ago

biggzlar commented 4 years ago

https://github.com/mila-iqia/atari-representation-learning/blob/59f3e3b94c4a0a61a6cb63acf247dd5a8662fadd/atariari/methods/global_local_infonce.py#L110

This line is a little confusing to me. How does the range over N samples correspond to our cross-entropy targets? All help is highly appreciated!

eracah commented 4 years ago

Hi Lennart, TL;DR: the other elements in the batch act as negative examples.

So N is the batch size.

When we matrix-multiply the two, as in line 109: logits = torch.matmul(predictions, positive.t()), the result is logits, which is an N x N matrix.

Element i,j of this matrix is the dot product of the ith global vector in the batch with the jth feature map (at the y,x spatial location) in the batch. When i = j, we have a "positive pair", because both elements of the pair are at the same batch index, which means they come from consecutive examples in an episode. When i != j, the pair is not from consecutive examples, so it is a "negative pair". So the diagonal of "logits" contains the dot products of positive pairs, and the off-diagonal entries are negative pairs.

The contrastive loss can be seen as classifying the positive pair against the other, negative pairs, which is why we use cross-entropy.

So if we treat each of the N rows of "logits" as the logits for N separate classification problems, then the correct class for row 0 is the 0th element, for row 1 it's the 1st element, for row 2 it's the 2nd, and for row N-1 it's the (N-1)st element, because those are the elements on the diagonal of the matrix. Hence the target class for row 0 is 0, for row 1 is 1, ..., and for row N-1 is N-1, so "target" is [0, 1, ..., N-1], aka torch.arange(N).
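For concreteness, here is a minimal self-contained sketch of the computation described above (random tensors stand in for the real global and local features, and the batch size and feature dimension are arbitrary; this is not the repo's exact code):

```python
import torch
import torch.nn as nn

N, feature_dim = 8, 256
predictions = torch.randn(N, feature_dim)  # global vectors for the t+1 frames (stand-in)
positive = torch.randn(N, feature_dim)     # local features at one (y, x) location for the t frames (stand-in)

# (N, N) score matrix: entry (i, j) = dot(predictions[i], positive[j]).
# Diagonal entries pair samples at the same batch index (positive pairs);
# off-diagonal entries mix different timesteps/episodes (negative pairs).
logits = torch.matmul(predictions, positive.t())

# Each row i is an N-way classification problem whose correct class is column i,
# so the target vector is [0, 1, ..., N-1].
target = torch.arange(N)
loss = nn.functional.cross_entropy(logits, target)
print(loss)
```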

Let me know if that helps!

biggzlar commented 4 years ago

Hey Evan, thanks a lot for your time and the detailed explanation. That makes it a lot clearer! I suppose I was thrown off by the use of N distinct class labels instead of a binary cross-entropy over positive and negative samples, but I see now how this makes sense.

I hope you don't mind me asking a follow-up question. Is it true that this loss only captures the global-local objective? Maybe I am failing to see the local-local element?

eracah commented 4 years ago

Hi Lennart, no problem! Yeah, you're correct: "global_local_infonce.py" is just the ablation where we remove the local-local loss. The full loss used in the paper, with both terms, is in stdim.py.
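For later readers, here is a rough, heavily simplified sketch of how the two terms can be combined (the function name st_dim_style_loss and the plain dot-product scoring at each spatial location are illustrative assumptions, not the repo's actual code; see stdim.py for the real implementation):

```python
import torch
import torch.nn as nn

def st_dim_style_loss(global_tp1, local_t, local_tp1):
    """Simplified sketch, not the repo's exact code.

    global_tp1: (N, d)        global vectors for frames at time t+1
    local_t:    (N, H, W, d)  local feature maps for frames at time t
    local_tp1:  (N, H, W, d)  local feature maps for frames at time t+1
    """
    N, H, W, d = local_t.shape
    target = torch.arange(N)
    loss_gl, loss_ll = 0.0, 0.0
    for y in range(H):
        for x in range(W):
            # Global-local term: global vector at t+1 vs. local feature at t.
            logits_gl = torch.matmul(global_tp1, local_t[:, y, x, :].t())
            loss_gl = loss_gl + nn.functional.cross_entropy(logits_gl, target)
            # Local-local term: local features at t+1 vs. t, same (y, x) location.
            logits_ll = torch.matmul(local_tp1[:, y, x, :], local_t[:, y, x, :].t())
            loss_ll = loss_ll + nn.functional.cross_entropy(logits_ll, target)
    return (loss_gl + loss_ll) / (H * W)

# Example with random features standing in for encoder outputs:
N, H, W, d = 8, 5, 5, 256
loss = st_dim_style_loss(torch.randn(N, d), torch.randn(N, H, W, d), torch.randn(N, H, W, d))
print(loss)
```

Both terms use the same cross-entropy-over-the-batch trick discussed above; the ablation in global_local_infonce.py corresponds to dropping the local-local term.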

biggzlar commented 4 years ago

Ha, oh well. Should have read the filenames more carefully. Thanks a lot for the speedy support.