MishaLaskin / curl

CURL: Contrastive Unsupervised Representation Learning for Sample-Efficient Reinforcement Learning
MIT License

Using cross-entropy loss does not penalize negative samples #9

Closed xlnwel closed 4 years ago

xlnwel commented 4 years ago

Hi,

I see that you use the cross-entropy (CE) loss for contrastive learning. As far as I understand, this does not penalize the negative samples, since the CE loss gives zero weight to the off-diagonal entries of the [B, B] matrix. Am I making a mistake?

Best, Sherwin

MishaLaskin commented 4 years ago

https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html


MishaLaskin commented 4 years ago

CE penalizes entries labeled with 0 (off-diagonal) and encourages entries labeled with 1 (on-diagonal). Each row of the [B, B] logits matrix goes through a softmax, so the off-diagonal (negative) logits appear in the denominator: minimizing the loss pushes the positive logit up and the negative logits down.
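A minimal sketch (not the actual CURL training code) illustrating the point: applying `F.cross_entropy` to a [B, B] similarity matrix with diagonal targets is the InfoNCE loss, and the off-diagonal logits receive nonzero gradient through the softmax denominator, so negatives are indeed penalized rather than ignored.

```python
import torch
import torch.nn.functional as F

B = 4
# Hypothetical anchor-vs-key similarity matrix; in CURL this would come
# from the query/key encoders and the bilinear product.
logits = torch.randn(B, B, requires_grad=True)
labels = torch.arange(B)  # positives sit on the diagonal

loss = F.cross_entropy(logits, labels)
loss.backward()

# For a non-target logit, dL/dlogit = softmax(logits)[i, j] / B > 0,
# i.e. every negative logit is pushed down, not given zero weight.
off_diag = ~torch.eye(B, dtype=torch.bool)
print((logits.grad[off_diag] > 0).all())  # negatives get positive gradient
```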


xlnwel commented 4 years ago

My mistake. I mistook the logits for probabilities. Thanks for your help :-)