KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

NTXent Loss Gradient Flow #442

Closed · YilmazKadir closed this issue 2 years ago

YilmazKadir commented 2 years ago

Do I need to handle anything manually (stop gradients, etc.) when using NTXent Loss in a SimCLR-like network, to avoid updating the weights multiple times? In a SimCLR-like network, the same network weights feed into the NTXent loss through different samples. Can you briefly explain how you handle gradient updates for NTXent loss?

KevinMusgrave commented 2 years ago

I don't think SimCLR uses stop gradients, but SimSiam does. The SimSiam paper contains pseudocode you can follow. Basically, you would call `.detach()` on one of the inputs before passing it into the NTXent loss.