KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Hierarchical Triplet Loss #314

Open kachayev opened 3 years ago

kachayev commented 3 years ago

Hey!

Any objections to adding Hierarchical Triplet Loss (HTL), described here? Happy to work on a Pull Request.

KevinMusgrave commented 3 years ago

That would be great!

I'm wondering if the hierarchical tree should be implemented as a sampler, rather than inside a loss function. Something like:

import torch

class HierarchicalTree(torch.utils.data.Sampler):
    def __init__(self, dataset, model, distance_metric):
        self.dataset = dataset
        self.model = model
        self.distance_metric = distance_metric
        self.tree = ...

    def __iter__(self):
        # update self.tree
        # return indices for sampling
        raise NotImplementedError

The loss function would have to accept a reference to the tree as input:

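from pytorch_metric_learning.losses import BaseMetricLossFunction

class HierarchicalTripletLoss(BaseMetricLossFunction):
    def __init__(self, tree, **kwargs):
        super().__init__(**kwargs)
        self.tree = tree  # reference to the tree maintained by the sampler

    def compute_loss(self, embeddings, labels, indices_tuple):
        # compute margins based on self.tree
        raise NotImplementedError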
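
The "update self.tree" step is left open in the sampler sketch. One possible shape for it, as a rough sketch only: given a matrix of class-to-class distances, record the distance at which any two classes first fall into the same cluster. The function name `merge_heights`, the plain list-of-lists input, and the choice of single-linkage merging are all assumptions made for illustration; they are not taken from the linked paper or from this library.

```python
from itertools import combinations

def merge_heights(class_dist):
    """Single-linkage merge height for every pair of classes (hypothetical helper).

    class_dist: symmetric square list of lists of class-to-class distances.
    Returns a matrix h where h[i][j] is the distance at which classes
    i and j first end up in the same cluster.
    """
    n = len(class_dist)
    clusters = {i: [i] for i in range(n)}  # cluster id -> member classes
    owner = list(range(n))                 # class -> id of its current cluster
    h = [[0.0] * n for _ in range(n)]
    # Kruskal-style pass: visit class pairs from closest to farthest.
    pairs = sorted(combinations(range(n), 2), key=lambda p: class_dist[p[0]][p[1]])
    for i, j in pairs:
        ci, cj = owner[i], owner[j]
        if ci == cj:
            continue  # already merged at a smaller distance
        d = class_dist[i][j]
        # Every cross-cluster pair meets for the first time at distance d.
        for a in clusters[ci]:
            for b in clusters[cj]:
                h[a][b] = h[b][a] = d
        # Fold cluster cj into cluster ci.
        for b in clusters[cj]:
            owner[b] = ci
        clusters[ci].extend(clusters.pop(cj))
    return h

# Three classes: 0 and 1 are close, 2 is far from both.
class_dist = [[0.0, 1.0, 4.0], [1.0, 0.0, 3.0], [4.0, 3.0, 0.0]]
heights = merge_heights(class_dist)
```

A table like `heights` could then serve as the lookup behind a per-pair margin, with the sampler recomputing it from fresh embeddings at whatever interval turns out to be affordable.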

KevinMusgrave commented 3 years ago

HierarchicalTripletLoss can probably extend TripletMarginLoss:

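from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class HierarchicalTripletLoss(TripletMarginLoss):
    def __init__(self, tree, **kwargs):
        super().__init__(**kwargs)
        self.tree = tree

    def get_margins_from_tree(self, la, ln):
        # use self.tree to compute margins based on anchor and negative labels
        raise NotImplementedError

    def compute_loss(self, embeddings, labels, indices_tuple):
        # make sure we have explicit (anchor, positive, negative) indices
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, t_per_anchor=self.triplets_per_anchor
        )
        a, _, n = indices_tuple
        la, ln = labels[a], labels[n]
        # one margin per triplet, looked up from anchor and negative labels
        self.margin = self.get_margins_from_tree(la, ln)
        return super().compute_loss(embeddings, labels, indices_tuple)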
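
The key idea in this sketch is that the margin becomes one value per triplet instead of a single scalar. A minimal plain-PyTorch illustration of that idea follows. It does not use this library's classes or reducers, and `triplet_loss_with_margins` and the `class_margin` lookup table are hypothetical names introduced here, standing in for whatever `get_margins_from_tree` would eventually return.

```python
import torch

def triplet_loss_with_margins(embeddings, labels, triplets, class_margin):
    """Triplet margin loss where each triplet gets its own margin.

    class_margin: hypothetical (num_classes, num_classes) tensor; entry
    [la, ln] is the margin for anchor class la and negative class ln.
    """
    a, p, n = triplets
    # Look up one margin per triplet from the anchor and negative labels.
    margins = class_margin[labels[a], labels[n]]
    d_ap = torch.norm(embeddings[a] - embeddings[p], dim=1)
    d_an = torch.norm(embeddings[a] - embeddings[n], dim=1)
    # Hinge on each triplet, then a plain mean over triplets.
    return torch.relu(d_ap - d_an + margins).mean()

# Tiny example: samples 0 and 1 share class 0, sample 2 is class 1.
embeddings = torch.tensor([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
labels = torch.tensor([0, 0, 1])
triplets = (torch.tensor([0]), torch.tensor([1]), torch.tensor([2]))
class_margin = torch.tensor([[0.0, 2.5], [2.5, 0.0]])
loss = triplet_loss_with_margins(embeddings, labels, triplets, class_margin)
```

Here the anchor-positive distance is 1 and the anchor-negative distance is 3, so the triplet only contributes to the loss because the class-pair margin (2.5) is larger than the gap of 2.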

kachayev commented 3 years ago

Yeah, I think extending TripletMarginLoss is a good choice (also, semantically it makes a lot of sense).

I will submit a PR as soon as the first draft is ready!

ibebrett commented 3 years ago

@kachayev , out of curiosity, are you still working on this PR? If not, I'd be interested in taking a look at it.

kachayev commented 3 years ago

@ibebrett Not as of now, feel free to chime in!

majam10 commented 3 years ago

Hi, I would like to use hierarchical triplet loss for my own project and was simply wondering if anyone was able to translate it into code? Thanks in advance!

Acura-bit commented 1 year ago

Is the code available now?

KevinMusgrave commented 1 year ago

No