Open kachayev opened 3 years ago
That would be great!
I'm wondering if the hierarchical tree should be implemented as a sampler, rather than inside a loss function. Something like:
import torch

class HierarchicalTree(torch.utils.data.Sampler):
    def __init__(self, dataset, model, distance_metric):
        self.dataset = dataset
        self.model = model
        self.distance_metric = distance_metric
        self.tree = ...  # placeholder for the class hierarchy

    def __iter__(self):
        # update self.tree
        # return indices for sampling
        ...
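To make the sampler idea a bit more concrete, here is a minimal sketch of one possible sampling policy: pick a few anchor classes, add each anchor's nearest classes, and draw a fixed number of indices per class. Everything here is an assumption for illustration: the `NeighborClassSampler` name, its parameters, and the `class_distance` callable are hypothetical, and rebuilding the tree from model embeddings is left out. It is written in plain Python so it runs without torch; `DataLoader` documents its `sampler` argument as accepting any iterable of indices, so a duck-typed class like this should be usable there, though that has not been tried here.

```python
import random
from collections import defaultdict


class NeighborClassSampler:
    """Hypothetical sampler that groups indices from nearby classes together."""

    def __init__(self, labels, class_distance, num_anchors=2,
                 classes_per_anchor=2, samples_per_class=2, seed=0):
        self.class_distance = class_distance  # callable (label_a, label_b) -> float
        self.num_anchors = num_anchors
        self.classes_per_anchor = classes_per_anchor
        self.samples_per_class = samples_per_class
        self.rng = random.Random(seed)
        # Map each label to the dataset indices that carry it.
        self.by_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_class[label].append(idx)

    def __len__(self):
        return self.num_anchors * self.classes_per_anchor * self.samples_per_class

    def __iter__(self):
        classes = sorted(self.by_class)
        for anchor in self.rng.sample(classes, self.num_anchors):
            # The anchor class itself sorts first (distance 0), then its neighbours.
            nearest = sorted(classes, key=lambda c: self.class_distance(anchor, c))
            for cls in nearest[: self.classes_per_anchor]:
                # Sample with replacement so small classes still fill their quota.
                for _ in range(self.samples_per_class):
                    yield self.rng.choice(self.by_class[cls])


labels = [0, 0, 1, 1, 2, 2, 3, 3]
sampler = NeighborClassSampler(labels, class_distance=lambda a, b: abs(a - b))
indices = list(sampler)
```

In a real implementation, `class_distance` would presumably be answered by the tree, and the tree would be refreshed at the start of each `__iter__` call as suggested above.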
The loss function would then have to accept a reference to the tree as input:
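from pytorch_metric_learning.losses import BaseMetricLossFunction

class HierarchicalTripletLoss(BaseMetricLossFunction):
    def __init__(self, tree, **kwargs):
        super().__init__(**kwargs)  # initialize the base loss before setting attributes
        self.tree = tree

    def compute_loss(self, embeddings, labels, indices_tuple):
        # compute margins based on self.tree
        ...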
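As a rough illustration of what "compute margins based on the tree" changes in the loss itself, the sketch below evaluates a triplet loss where every triplet has its own margin instead of one global value. It is a plain-Python stand-in for the tensor code, not the library's implementation: `triplet_loss_with_margins` and `euclidean` are hypothetical helpers, and the plain mean over triplets is an assumption.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))


def triplet_loss_with_margins(embeddings, triplets, margins):
    """Mean of max(0, d(a, p) - d(a, n) + margin_i) over all triplets."""
    losses = []
    for (a, p, n), margin in zip(triplets, margins):
        d_ap = euclidean(embeddings[a], embeddings[p])
        d_an = euclidean(embeddings[a], embeddings[n])
        losses.append(max(0.0, d_ap - d_an + margin))
    return sum(losses) / len(losses) if losses else 0.0


# Index 0 is the anchor, 3 the positive, 1 and 2 the negatives.
embeddings = [(0.0, 0.0), (3.0, 4.0), (6.0, 8.0), (0.0, 1.0)]
triplets = [(0, 3, 1), (0, 3, 2)]
margins = [5.0, 1.0]  # one margin per triplet, e.g. looked up from the tree
loss = triplet_loss_with_margins(embeddings, triplets, margins)
```

The first triplet violates its large margin and contributes to the loss, while the second one, with a distant negative and a small margin, contributes nothing.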
HierarchicalTripletLoss can probably extend TripletMarginLoss:
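from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class HierarchicalTripletLoss(TripletMarginLoss):
    def __init__(self, tree, **kwargs):
        super().__init__(**kwargs)  # keep TripletMarginLoss options such as triplets_per_anchor
        self.tree = tree

    def get_margins_from_tree(self, la, ln):
        # use self.tree to compute margins based on anchor and negative labels
        raise NotImplementedError

    def compute_loss(self, embeddings, labels, indices_tuple):
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, t_per_anchor=self.triplets_per_anchor
        )
        a, _, n = indices_tuple
        la, ln = labels[a], labels[n]
        # one margin per triplet, replacing the single fixed margin
        self.margin = self.get_margins_from_tree(la, ln)
        return super().compute_loss(embeddings, labels, indices_tuple)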
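One way `get_margins_from_tree` could work, sketched under stated assumptions: store the hierarchy as parent pointers plus a merge height per node, and set each margin to a base value plus the height of the lowest common ancestor of the anchor and negative classes. The `ClassTree` structure, the `base_margin` parameter, and the margin rule are all hypothetical and only meant to show the lookup; this is a standalone function on Python lists rather than a method on tensors.

```python
class ClassTree:
    """Hypothetical class hierarchy with a merge height stored on every node."""

    def __init__(self, parent, height):
        self.parent = parent  # node -> parent node (the root maps to None)
        self.height = height  # node -> distance threshold at which the node was formed

    def ancestors(self, node):
        """Return the path from node up to the root, node included."""
        path = []
        while node is not None:
            path.append(node)
            node = self.parent[node]
        return path

    def lca_height(self, a, b):
        """Height of the lowest common ancestor of a and b."""
        seen = set(self.ancestors(a))
        for node in self.ancestors(b):
            if node in seen:
                return self.height[node]
        raise ValueError("nodes do not share a root")


def get_margins_from_tree(tree, anchor_labels, negative_labels, base_margin=0.25):
    """One margin per (anchor, negative) pair: base margin plus LCA height."""
    return [
        base_margin + tree.lca_height(a, n)
        for a, n in zip(anchor_labels, negative_labels)
    ]


tree = ClassTree(
    parent={"cat": "animal", "dog": "animal", "car": "root",
            "animal": "root", "root": None},
    height={"cat": 0.0, "dog": 0.0, "car": 0.0, "animal": 0.5, "root": 1.0},
)
margins = get_margins_from_tree(tree, ["cat", "cat"], ["dog", "car"])
```

With this rule, a negative from a distant branch of the tree ("car") gets a larger margin than a negative from a sibling class ("dog").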
Yeah, I think extending TripletMarginLoss is a good choice (it also makes a lot of sense semantically).
I will submit a PR as soon as the first draft is ready!
@kachayev, out of curiosity, are you still working on this PR? If not, I'd be interested in taking a look at it.
@ibebrett Not as of now, feel free to chime in!
Hi, I would like to use hierarchical triplet loss for my own project and was simply wondering if anyone was able to translate it into code? Thanks in advance!
Is the code available now?
No
Hey!
Any objections to adding Hierarchical Triplet Loss (HTL), described here? Happy to work on a pull request.