KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

SupCon loss #281

Closed: fjsj closed this 3 years ago

fjsj commented 3 years ago

Hi, thanks for this library! I have a pytorch-metric-learning-compatible version of the SupCon loss from the paper Supervised Contrastive Learning. The base code is below. It's based on this implementation (BSD licensed).

Are you open to a PR to include this loss? Any advice on how to write automated tests for it? This loss is similar to NTXentLoss, so I guess I can use its test as a basis.

from pytorch_metric_learning.distances import DotProductSimilarity
from pytorch_metric_learning.losses import GenericPairLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class SupConLoss(GenericPairLoss):
    def __init__(self, temperature, **kwargs):
        super().__init__(mat_based_loss=True, **kwargs)
        self.temperature = temperature
        self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)

    def _compute_loss(self, mat, pos_mask, neg_mask):
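        # mat is the pairwise similarity matrix from the distance (DotProductSimilarity
        # by default); pos_mask/neg_mask mark same-label and different-label pairs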
        sim_mat = mat / self.temperature
        sim_mat_max, _ = sim_mat.max(dim=1, keepdim=True)
        sim_mat = sim_mat - sim_mat_max.detach()  # for numerical stability

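        # the denominator runs over all other samples in the batch, positives and negatives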
        denominator = lmu.logsumexp(
            sim_mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
        )
        log_prob = sim_mat - denominator
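        # average the log-probabilities over each anchor's positives (the 1/|P(i)| term)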
        mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
            pos_mask.sum(dim=1) + c_f.small_val(sim_mat.dtype)
        )
        losses = self.temperature * mean_log_prob_pos

        return {
            "loss": {
                "losses": -losses,
                "indices": c_f.torch_arange_from_size(sim_mat),
                "reduction_type": "element",
            }
        }

    def get_default_reducer(self):
        return AvgNonZeroReducer()

    def get_default_distance(self):
        return DotProductSimilarity()

KevinMusgrave commented 3 years ago

Looks great! Yes, please make a PR.

You can add a unit test to the tests/losses folder. Call it test_supcon_loss.py.

Here's the test file for NTXentLoss, if you'd like a reference: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/losses/test_ntxent_loss.py

Then run it with python -m unittest tests/losses/test_supcon_loss.py
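
Something like this could work as a starting point; it checks the loss against a brute-force computation of the paper's formula. (The import path assumes the loss has been merged into the library; until then, you can import your SupConLoss class directly.)

import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.losses import SupConLoss  # assumed path, post-merge


class TestSupConLoss(unittest.TestCase):
    def test_supcon_loss(self):
        temperature = 0.1
        loss_fn = SupConLoss(temperature=temperature)

        # pre-normalize so the dot products below match the loss's internal
        # DotProductSimilarity, which normalizes embeddings by default
        embeddings = F.normalize(torch.randn(8, 32), dim=1)
        labels = torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3])

        # brute-force reference computation, anchor by anchor
        per_anchor_losses = []
        for i in range(len(embeddings)):
            others = [j for j in range(len(embeddings)) if j != i]
            positives = [j for j in others if labels[j] == labels[i]]
            sims = torch.stack([embeddings[i] @ embeddings[j] for j in others])
            denom = torch.logsumexp(sims / temperature, dim=0)
            log_probs = torch.stack(
                [embeddings[i] @ embeddings[p] / temperature - denom for p in positives]
            )
            per_anchor_losses.append(-temperature * log_probs.mean())
        expected = torch.stack(per_anchor_losses).mean()

        actual = loss_fn(embeddings, labels)
        self.assertTrue(torch.isclose(actual, expected, rtol=1e-4))


if __name__ == "__main__":
    unittest.main()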

Looking at your code, I just realized I should be detaching the max value in NTXentLoss. Interestingly, I was able to reproduce MoCo on CIFAR10 with this loss, so I guess it's not a devastating bug 🤔

KevinMusgrave commented 3 years ago

You might find yourself copy-pasting a lot from the unit test for NTXentLoss. Currently there is no way around that because I haven't written any helper functions/classes specifically for the test suite.

KevinMusgrave commented 3 years ago

After looking at the paper, I'm not sure if SupConLoss is any different from NTXentLoss.

fjsj commented 3 years ago

> After looking at the paper, I'm not sure if SupConLoss is any different from NTXentLoss.

Maybe the way NTXentLoss is implemented in pytorch-metric-learning already covers the specifics of SupConLoss? In the paper, the main difference I see is the division by |P(i)| in L_out^sup.
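
For reference, L_out^sup from the paper is

$$\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

where P(i) is the set of positives for anchor i and A(i) is every other sample in the batch.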

KevinMusgrave commented 3 years ago

Ah, I see. I'll think about this and get back to you.

KevinMusgrave commented 3 years ago

I added PerAnchorReducer to convert unreduced pair losses to unreduced element-wise losses, which are then passed to an inner reducer.

from pytorch_metric_learning.reducers import PerAnchorReducer, AvgNonZeroReducer
from pytorch_metric_learning.losses import NTXentLoss

# equivalent to SupConLoss
loss_fn = NTXentLoss(reducer=PerAnchorReducer())

# the default reducer inside PerAnchorReducer is MeanReducer,
# but you can use any other reducer
reducer = PerAnchorReducer(reducer=AvgNonZeroReducer())
loss_fn = NTXentLoss(reducer=reducer)
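
For example, it's called like any other loss in the library:

import torch

embeddings = torch.randn(16, 64)  # (batch_size, embedding_dim)
labels = torch.randint(0, 4, (16,))
loss = loss_fn(embeddings, labels)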

Unfortunately, PerAnchorReducer adds computational overhead, because it has to "assemble" all the pairs per batch element.

Since SupConLoss seems to be quite popular, I think it makes sense to add the version you have proposed. So please feel free to open a pull request.

fjsj commented 3 years ago

Thanks! I've created the PR: https://github.com/KevinMusgrave/pytorch-metric-learning/pull/288

KevinMusgrave commented 3 years ago

Looks good, I'll take a closer look (probably sometime next week).

KevinMusgrave commented 3 years ago

Available as:

pip install pytorch-metric-learning==0.9.98.dev0

fjsj commented 3 years ago

That's great, thanks!