KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Allow additional parameters in `compute_mat` #697

Open celsofranssa opened 4 months ago

celsofranssa commented 4 months ago

Hello,

The following code snippet shows a custom distance function that scales the simple dot distance with the rewards associated with each embedding.

import torch
from pytorch_metric_learning.distances import BaseDistance

class CustomDistance(BaseDistance):

    def __init__(self, params):
        # Dot-product similarity: larger values mean closer, so is_inverted=True.
        super().__init__(is_inverted=True)
        assert self.is_inverted
        # dict mapping each ref_emb id to its reward
        self.rewards = params.rewards

    def compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb):
        mat = 20 * torch.einsum("ab,cb->ac", embeddings, ref_emb)
        # Scale each column by the reward of the corresponding reference embedding.
        for col, ref_emb_idx in ref_emb_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat

However, the loss function calls `mat = self.distance(embeddings, ref_emb)`, which does not allow calling the overridden method `compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb)`. That method needs the ids of the embeddings and of ref_emb to look up each embedding's reward and scale the corresponding distance value.

Is there a workaround?

Thank you.

celsofranssa commented 4 months ago

Hello @KevinMusgrave, do you have any suggestions here?

KevinMusgrave commented 4 months ago

Would overriding the forward method work?

https://github.com/KevinMusgrave/pytorch-metric-learning/blob/adfb78ccb275e5bd7a84d2693da41ff3002eed32/src/pytorch_metric_learning/distances/base_distance.py#L26

I guess it depends on where embeddings_ids and ref_emb_ids are coming from.

KevinMusgrave commented 4 months ago

Or do you mean you need to change the mat = self.distance(embeddings, ref_emb) in the loss function?

celsofranssa commented 4 months ago

> Or do you mean you need to change the mat = self.distance(embeddings, ref_emb) in the loss function?

> Would overriding the forward method work?

Exactly @KevinMusgrave,

I am already using the very nice NTXentLoss. Recently, I discovered that scaling the simple dot distance with the rewards associated with each embedding is beneficial for my research task. At this point, I need to access the embedding identifier because each embedding has its associated reward.

KevinMusgrave commented 4 months ago

I see, so you also need to modify NTXentLoss to take in those ids?

celsofranssa commented 4 months ago

> I see, so you also need to modify NTXentLoss to take in those ids?

If there is no other overriding approach, yes.

KevinMusgrave commented 4 months ago

Yeah I don't think there's another approach. The only other way would be to set some attribute of your custom distance object before computing the loss.

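from pytorch_metric_learning.losses import NTXentLoss

dist_fn = CustomDistance()
loss_fn = NTXentLoss(distance=dist_fn)

...
# Set the ids for the current batch before computing the loss.
dist_fn.curr_ref_ids = ref_ids
loss = loss_fn(...)

# Inside dist_fn, refer to self.curr_ref_ids.

celsofranssa commented 4 months ago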

> Yeah I don't think there's another approach. The only other way would be to set some attribute of your custom distance object before computing the loss.
>
> dist_fn = CustomDistance()
> loss_fn = NTXentLoss(distance=dist_fn)
>
> ...
> dist_fn.curr_ref_ids = ref_ids
> loss = loss_fn(...)
>
> # inside dist_fn refer to self.curr_ref_ids

I see. For a future release, it would be great if the `mat = self.distance(embeddings, ref_emb)` call accepted additional parameters, so that custom distance implementations like this one could be supported.

celsofranssa commented 1 month ago

Hello,

Is there any progress here?

KevinMusgrave commented 1 month ago

Sorry, no progress yet.