KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Add Circle Loss #48

Closed: AlenUbuntu closed this issue 4 years ago

AlenUbuntu commented 4 years ago

Hello, I'd like to add Circle Loss to the package, and I have already finished implementing it.

FYI, here is the CVPR 2020 paper that introduces Circle Loss: https://arxiv.org/abs/2002.10857
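For reference, this is the pair-wise form of the loss as I read it in the paper. Here $s_p^i$ ($i = 1, \dots, K$) are within-class similarities, $s_n^j$ ($j = 1, \dots, L$) are between-class similarities, $\gamma$ is a scale factor and $m$ is the relaxation margin. This is my own transcription, so it is worth checking against the paper:

$$
\mathcal{L}_{circle} = \log\Big[1 + \sum_{j=1}^{L} \exp\big(\gamma\, \alpha_n^j (s_n^j - \Delta_n)\big) \sum_{i=1}^{K} \exp\big(-\gamma\, \alpha_p^i (s_p^i - \Delta_p)\big)\Big]
$$

$$
\alpha_p^i = [O_p - s_p^i]_+, \quad \alpha_n^j = [s_n^j - O_n]_+, \quad O_p = 1 + m,\; O_n = -m,\; \Delta_p = 1 - m,\; \Delta_n = m
$$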

How can I add this to your repo?

KevinMusgrave commented 4 years ago

That would be awesome! Here's how I think it should be done:

  1. The class name should be CircleLoss

  2. It should be located in losses/circle_loss.py

  3. It should be added to losses/__init__.py like this:

    ...
    from .arcface_loss import ArcFaceLoss
    from .circle_loss import CircleLoss
    from .contrastive_loss import ContrastiveLoss
    ...
  4. Then the structure of the class should be like this:

    from .base_metric_loss_function import BaseMetricLossFunction

    class CircleLoss(BaseMetricLossFunction):
        def __init__(self, your_args, **kwargs):
            super().__init__(**kwargs)
            # init stuff (store the Circle Loss hyperparameters passed via your_args)

        def compute_loss(self, embeddings, labels, indices_tuple):
            # We can discuss in the pull request what to do with indices_tuple
            # calculate and return the loss
            raise NotImplementedError  # placeholder so the template is valid Python
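To make the discussion concrete, here is a minimal standalone sketch of what the body of `compute_loss` might compute. It is written in plain PyTorch rather than against `BaseMetricLossFunction`, whose internals it does not assume. It uses every valid pair in the batch with cosine similarity and ignores `indices_tuple`. The function name `circle_loss`, the defaults `m=0.25` and `gamma=256`, detaching the α weights, and averaging over anchors are all illustrative assumptions, not decisions made in this thread.

```python
import torch
import torch.nn.functional as F


def circle_loss(embeddings, labels, m=0.25, gamma=256.0):
    """Hypothetical pair-based Circle Loss sketch (not the library's API).

    embeddings: (N, D) float tensor; labels: (N,) integer tensor.
    """
    # Cosine similarity between every pair of embeddings
    x = F.normalize(embeddings, p=2, dim=1)
    sim = x @ x.t()

    n = labels.shape[0]
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(n, dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye  # same label, excluding self-pairs
    neg_mask = ~same        # different labels

    # Self-paced weights, treated as constants (assumption: detached):
    # alpha_p = [O_p - s_p]_+ with O_p = 1 + m; alpha_n = [s_n - O_n]_+ with O_n = -m
    alpha_p = (1 + m - sim.detach()).clamp(min=0.0)
    alpha_n = (sim.detach() + m).clamp(min=0.0)

    # Margins Delta_p = 1 - m and Delta_n = m
    logit_p = -gamma * alpha_p * (sim - (1 - m))
    logit_n = gamma * alpha_n * (sim - m)

    # Drop pairs that are not positives / negatives by sending them to -inf
    logit_p = logit_p.masked_fill(~pos_mask, float("-inf"))
    logit_n = logit_n.masked_fill(~neg_mask, float("-inf"))

    # Keep only anchors with at least one positive and one negative
    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not valid.any():
        return embeddings.sum() * 0.0  # zero loss that stays in the graph

    # log(1 + sum exp(logit_n) * sum exp(logit_p)), computed stably per anchor
    lse_p = torch.logsumexp(logit_p[valid], dim=1)
    lse_n = torch.logsumexp(logit_n[valid], dim=1)
    return F.softplus(lse_n + lse_p).mean()
```

For example, `circle_loss(torch.randn(8, 4), torch.tensor([0, 0, 1, 1, 2, 2, 3, 3]))` returns a scalar tensor. Whether the loss should instead be restricted to miner-selected pairs from `indices_tuple` is still the open question in the template.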

Can you open a pull request? We can discuss it further there.