KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Multi-label version of SupConLoss #658

Open penguinwang96825 opened 1 year ago

penguinwang96825 commented 1 year ago

This paper introduced a new loss function called MultiSupCon that allows us to gain knowledge about the degree of label overlap between pairs of samples. I was just wondering if it's possible to integrate this loss into your library?

KevinMusgrave commented 1 year ago

I'd have to take a closer look at the paper. I'll leave this issue open in case someone wants to implement it.

penguinwang96825 commented 1 year ago

Thanks for the prompt reply!

The supervised contrastive loss (SupCon) is defined as follows:

$$\mathcal{L} = \sum_{i \in I} \frac{-1}{\lvert P(i) \rvert} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

And the multi-label supervised contrastive loss (MultiSupCon) is defined as:

$$\mathcal{L} = \sum_{i \in I} \frac{-1}{\lvert N(i) \rvert} \sum_{p \in N(i)} s_{i,p} \cdot \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

where $s_{i,p}$ is the assimilation value mentioned in the paper: a weight ($0 < s_{i,p} < 1$) indicating the degree of overlap between the two sets of labels.
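
For concreteness, here is a minimal sketch of how those assimilation values could be computed from multi-hot label vectors. The helper name `assimilation_matrix` is mine, and I'm assuming a Jaccard-style overlap $\lvert L_i \cap L_p \rvert / \lvert L_i \cup L_p \rvert$; the exact definition should be taken from the paper:

```python
import torch

def assimilation_matrix(labels: torch.Tensor) -> torch.Tensor:
    """Pairwise label-overlap weights for a batch of multi-hot labels.

    labels: (batch_size, num_classes) tensor of 0/1 entries.
    Returns a (batch_size, batch_size) matrix whose entry (i, p) is
    |L_i ∩ L_p| / |L_i ∪ L_p|  (assumed Jaccard-style overlap).
    """
    labels = labels.float()
    intersection = labels @ labels.T                                 # |L_i ∩ L_p|
    union = labels.sum(1, keepdim=True) + labels.sum(1) - intersection
    return intersection / union.clamp(min=1)                        # avoid division by zero
```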

I looked into the pytorch_metric_learning code for SupConLoss, but it seems a little different from the formulation shown above, so it's hard for me to adapt it to MultiSupCon without significant effort. Basically, there are two challenges:

  1. The shape of the input labels.
  2. How to multiply the assimilation value with mean_log_prob_pos.

I have found the repo for the MultiSupCon loss.
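
To make the two challenges above more concrete, here's a rough sketch of the loss written directly from the MultiSupCon equation, not from the library's SupConLoss code path. The function name, the `temperature` default, and the use of the hypothetical `assimilation_matrix` helper from the previous snippet are all my own assumptions, not anything taken from the paper's repo:

```python
import torch
import torch.nn.functional as F

def multi_supcon_loss(embeddings, labels, temperature=0.1):
    """Sketch of MultiSupCon following the equation above.

    embeddings: (B, D) float tensor, labels: (B, C) multi-hot tensor.
    """
    z = F.normalize(embeddings, dim=1)
    sim = z @ z.T / temperature                                   # z_i . z_a / tau
    sim = sim - sim.max(dim=1, keepdim=True).values.detach()      # numerical stability

    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    exp_sim = torch.exp(sim).masked_fill(self_mask, 0.0)          # A(i) excludes i itself
    log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))

    s = assimilation_matrix(labels).masked_fill(self_mask, 0.0)   # s_{i,p}, see sketch above
    num_pos = (s > 0).sum(dim=1).clamp(min=1)                     # |N(i)|

    # -1/|N(i)| * sum_{p in N(i)} s_{i,p} * log_prob_{i,p}
    loss_per_anchor = -(s * log_prob).sum(dim=1) / num_pos
    # the equation sums over i; taking the batch mean is a common normalization choice
    return loss_per_anchor.mean()
```

The multi-hot `labels` shape (challenge 1) is handled by the assimilation helper, and the weighting (challenge 2) happens where `s` multiplies `log_prob` before averaging over positives.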