Closed · fjsj closed this issue 3 years ago
Looks great! Yes, please make a PR.
You can add a unit test to the tests/losses folder. Call it test_supcon_loss.py.
Here's the test file for NTXentLoss, if you'd like a reference: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/losses/test_ntxent_loss.py
Then run it with python -m unittest tests/losses/test_supcon_loss.py
Looking at your code, I just realized I should be detaching the max value in NTXentLoss. Interestingly, I was able to reproduce MoCo on CIFAR10 with this loss, so I guess it's not a devastating bug 🤔
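To illustrate why detaching the max is safe, here is a framework-free sketch of the standard log-sum-exp stabilization (the function names here are illustrative, not part of pytorch-metric-learning). Subtracting the max logit is a constant shift that cancels in the softmax ratio, so it changes neither the loss value nor its gradient, which is why the max can be treated as a constant (detached) in an autograd framework:

```python
import math

def log_softmax_stable(logits, target_idx):
    # Subtract the max logit before exponentiating to avoid overflow.
    # The shift cancels between numerator and denominator, so the
    # result is identical to the naive computation.
    m = max(logits)
    shifted = [x - m for x in logits]
    log_denom = math.log(sum(math.exp(x) for x in shifted))
    return shifted[target_idx] - log_denom

def log_softmax_naive(logits, target_idx):
    # Direct computation; overflows for large logits.
    return math.log(math.exp(logits[target_idx]) / sum(math.exp(x) for x in logits))
```

For small inputs both versions agree, but only the stable one survives logits like 1000.0, where the naive exp overflows.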
You might find yourself copy-pasting a lot from the unit test for NTXentLoss. Currently there is no way around that because I haven't written any helper functions/classes specifically for the test suite.
After looking at the paper, I'm not sure if SupConLoss is any different from NTXentLoss.
Maybe the way NTXentLoss is implemented in pytorch-metric-learning already covers the specifics of SupConLoss?
In the paper, the main difference I see is a division by |P(i)| in L^sup_out.
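For reference, this is the L^sup_out definition from the SupCon paper (using the paper's notation: P(i) is the set of positives for anchor i, A(i) the set of all other indices, tau the temperature, z the normalized embeddings):

$$
\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
$$

The 1/|P(i)| factor sitting outside the log is the per-anchor division in question.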
Ah, I see. I'll think about this and get back to you.
I added PerAnchorReducer to convert unreduced pair losses into unreduced element-wise losses, which are then passed to an inner reducer.
from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning.reducers import PerAnchorReducer, AvgNonZeroReducer

# equivalent to SupConLoss
loss_fn = NTXentLoss(reducer=PerAnchorReducer())

# the default inner reducer of PerAnchorReducer is MeanReducer,
# but you can pass in any other reducer
reducer = PerAnchorReducer(reducer=AvgNonZeroReducer())
loss_fn = NTXentLoss(reducer=reducer)
Unfortunately, PerAnchorReducer adds computational overhead, because it has to "assemble" all the pairs per batch element.
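As a toy sketch of the grouping step this reducer performs (this is an illustration of the idea, not the library's actual implementation), per-pair losses are bucketed by anchor, averaged within each bucket, and then reduced across anchors; the inner mean is what realizes the 1/|P(i)| division from the SupCon paper:

```python
from collections import defaultdict

def per_anchor_reduce(pair_losses, anchor_idx):
    """Average pair losses within each anchor, then across anchors.

    pair_losses: per-positive-pair loss values
    anchor_idx:  the anchor index each pair belongs to
    The inner mean divides each anchor's total by its number of
    positives (the 1/|P(i)| factor); the outer mean plays the role
    of the default MeanReducer.
    """
    groups = defaultdict(list)
    for loss, a in zip(pair_losses, anchor_idx):
        groups[a].append(loss)
    per_anchor = [sum(v) / len(v) for v in groups.values()]
    return sum(per_anchor) / len(per_anchor)
```

For example, pair losses [2.0, 4.0, 6.0] with anchors [0, 0, 1] give per-anchor means 3.0 and 6.0, hence 4.5 overall. The "assembly" of pairs into per-anchor buckets is the overhead mentioned above.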
Since SupConLoss seems to be quite popular, I think it makes sense to add the version you have proposed. So please feel free to open a pull request.
Thanks! I've created the PR: https://github.com/KevinMusgrave/pytorch-metric-learning/pull/288
Looks good, I'll take a closer look (probably sometime next week)
Available as:
pip install pytorch-metric-learning==0.9.98.dev0
That's great, thanks!
Hi, thanks for this library! I have a pytorch-metric-learning-compatible version of the SupCon loss from the paper Supervised Contrastive Learning. The base code is below. It's based on this implementation (BSD licensed).
Are you open to a PR to include this loss? Any advice on how to write automated tests for it? This loss is similar to NTXentLoss, so I guess I can use its test as a basis.