KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Error of Centroid Triplet Loss (version 1.2.0) #451

Open fjssharpsword opened 2 years ago

fjssharpsword commented 2 years ago

In version 1.2.0, Centroid Triplet Loss is unstable. When the batch size is small, the error below always occurs; with a larger batch size it happens less often.

File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 38, in forward embeddings, labels, indices_tuple, ref_emb, ref_labels File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple] File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp> indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple] RuntimeError: shape '[6, -1]' is invalid for input of size 70330

KevinMusgrave commented 2 years ago

Could you provide a snippet of code that produces the error?

fjssharpsword commented 2 years ago

Could you provide a snippet of code that produces the error?

Thank you for your attention! Our code is as follows, and loss_func is:

CentroidTripletLoss(margin=0.05, swap=False, smooth_loss=False, triplets_per_anchor="all")

# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()

Now we find this issue is unrelated to batch size; it appears to be caused by the length of the input x not being divisible by the length of the one_labels list. Hence, we are trying to tune the length returned by lmu.get_all_triplets_indices or to set the parameter triplets_per_anchor.
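
For illustration, the failing reshape can be reproduced in isolation with the sizes from the traceback above (a sketch of the shape mismatch only, not the library code):

import torch

# 70330 triplet indices cannot be reshaped into rows of length len(one_labels) == 6,
# because 70330 is not divisible by 6
x = torch.arange(70330)
x.view(6, -1)  # RuntimeError: shape '[6, -1]' is invalid for input of size 70330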

fjssharpsword commented 2 years ago

We temporarily added a workaround in the compute_loss function of losses.CentroidTripletLoss, as follows:

# make only query vectors be anchor vectors
indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]

# added by Jason.Fang
remainder = len(indices_tuple[0]) % len(one_labels)  # length may not be divisible
quotient = len(indices_tuple[0]) // len(one_labels)
indices_tuple = [x[: len(indices_tuple[0]) - remainder] for x in indices_tuple]

# make only pos_centroids be positive examples
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]

# make only neg_centroids be negative examples
indices_tuple = [x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple]

This is a temporary solution and may hurt the results because triplet samples are truncated. We look forward to an effective solution.

KevinMusgrave commented 2 years ago

@cwkeam Do you remember how this part works? https://github.com/KevinMusgrave/pytorch-metric-learning/blob/58247798ca9bf62ff49874e5cd07c41424e64fe9/src/pytorch_metric_learning/losses/centroid_triplet_loss.py#L124

It fails this test that I've added to the dev branch: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/3930db293cd2c71181ffb776cc71b5e8a7d1f502/tests/losses/test_centroid_triplet_loss.py#L20-L24

The error message:

Traceback (most recent call last):
  File "pytorch-metric-learning/tests/losses/test_centroid_triplet_loss.py", line 24, in test_indices_tuple_failure
    loss_fn(embeddings, labels)
  File "python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "pytorch-metric-learning/src/pytorch_metric_learning/losses/base_metric_loss_function.py", line 37, in forward
    loss_dict = self.compute_loss(
  File "pytorch-metric-learning/src/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss
    indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
  File "pytorch-metric-learning/src/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp>
    indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
RuntimeError: shape '[3, -1]' is invalid for input of size 52

Command to run the test: python -m unittest tests.losses.test_centroid_triplet_loss.TestCentroidTripletLoss.test_indices_tuple_failure

cwkeam commented 2 years ago

@KevinMusgrave @fjssharpsword thanks for pointing this out, I'll take a look.

cwkeam commented 2 years ago

@KevinMusgrave @fjssharpsword

After reviewing the code and the old discussions about this loss function, the reason the above test fails is simply that there's a class with only one embedding.

The code comments, input parameter checks, and documentation haven't been clear on this, and I've adjusted the code locally to improve that.

If you look at an old discussion post here, though it isn't exactly about this, you can see Kevin say:

labels = [0,0,0,1,1,1]
The positive centroid is the average of embeddings 1 and 2, 
and the anchor is embedding 0, so that makes sense.

In the above example with [0,0,0,1,1,1], there are 6 embeddings, [0,1,2] in class A and [3,4,5] in class B. Let's compute the loss for each embedding:

# 1
anchor = embs[0]
positive_centroid = avg(embs[1], embs[2])
negative_centroid = avg(embs[3:6])

# 2
anchor = embs[1]
positive_centroid = avg(embs[0], embs[2])
negative_centroid = avg(embs[3:6])

...
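
As a plain-PyTorch sketch of that arithmetic (toy random embeddings for labels [0,0,0,1,1,1]; illustrative only, not the library's internals):

import torch

embs = torch.randn(6, 8)  # 6 embeddings with labels [0, 0, 0, 1, 1, 1]

# case 1: anchor is embedding 0; its positive centroid excludes the anchor itself
anchor = embs[0]
positive_centroid = embs[1:3].mean(dim=0)   # avg(embs[1], embs[2])
negative_centroid = embs[3:6].mean(dim=0)   # avg(embs[3], embs[4], embs[5])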

The reason we take the anchor embedding out of the positive centroid is that it's kind of perverse to think about the distance of an embedding to a centroid it's included in. It's also implemented this way in the code I referenced for this PR.

So the above process clearly breaks for cases with only one sample per class (i.e. avg([]) for the positive centroid).

It's not really about batch size: test cases with arbitrarily imbalanced embedding sets pass regardless of how they are arranged, and fail only if the embeddings imply a class with just one sample.
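
A minimal toy input that would trigger this failure (illustrative labels, not taken from the thread):

import torch
from pytorch_metric_learning import losses

# class 1 contributes only one embedding, so its positive centroid would be
# an average over an empty set; this label layout makes the loss fail
labels = torch.tensor([0, 0, 0, 1])
embeddings = torch.randn(4, 8)
losses.CentroidTripletLoss()(embeddings, labels)  # expected to fail here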

@KevinMusgrave I'm leaning towards raising a ValueError for cases like this. Or do you think I should support it somehow? If it had to be supported, the loss for the lone anchor would just be defined by the negative-centroid term.

Let me know what you think!

KevinMusgrave commented 2 years ago

Thanks @cwkeam, I think raising a ValueError makes sense.

KevinMusgrave commented 2 years ago

In v1.3.0, a ValueError will be raised if any of the labels have only 1 embedding.
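
A minimal sketch of such a check (the helper name is made up here, and the actual v1.3.0 implementation may differ):

import torch

def check_at_least_two_per_label(labels):
    # raise if any label appears only once, since a positive centroid
    # cannot be formed from zero remaining embeddings
    _, counts = torch.unique(labels, return_counts=True)
    if (counts < 2).any():
        raise ValueError("CentroidTripletLoss requires at least 2 embeddings per label")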

amirdnc commented 2 years ago

I get a similar error even when each label has more than 1 embedding. It seems that the discussed line of code:

https://github.com/KevinMusgrave/pytorch-metric-learning/blob/58247798ca9bf62ff49874e5cd07c41424e64fe9/src/pytorch_metric_learning/losses/centroid_triplet_loss.py#L124

assumes that the number of labels evenly divides len(indices_tuple), which isn't always the case.

KevinMusgrave commented 2 years ago

@amirdnc Could you provide a snippet of code so that we can reproduce the error? I think the labels array would be sufficient.

amirdnc commented 2 years ago

Sure. Try this:

    loss_func = losses.CentroidTripletLoss()
    l = torch.tensor([228, 228, 228, 228, 228, 228, 228, 228, 228, 228,  83,  83,  83,  83,
         83,  83,  83,  83,  83,  83, 200, 200, 200, 200, 200, 200, 200, 200,
        200, 200, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 250, 250,
        250, 250, 250, 250, 250, 250, 250, 250,  13,  13,  13,  13,  13,  13,
         13,  13,  13,  13, 124, 124, 124, 124, 124, 124, 124, 124, 124, 124,
         86,  86,  86,  86,  86,  86,  86,  86,  86,  86, 179, 179, 179, 179,
        179, 179, 179, 179, 179, 179, 338, 338, 338, 338, 338, 338, 338, 338,
        338, 338, 146, 146, 146, 146, 146, 146, 146, 146, 146, 146, 225, 225,
        225, 225, 225, 225, 225, 225, 225, 225,  22,  22,  22,  22,  22,  22,
         22,  22,  22,  22,  73,  73,  73,  73,  73,  73,  73,  73,  73,  73,
        300, 300, 300, 300, 300, 300, 300, 300, 300, 300, 431, 431, 431, 431,
        431, 431, 431, 431, 431, 431, 349, 349, 349, 349, 349, 349, 349, 349,
        349, 349, 305, 305, 305, 305, 305, 305, 305, 305, 305, 305, 117, 117,
        117, 117, 117, 117, 117, 117, 117, 117,  56,  56, 206, 206, 206, 206,
        206, 206, 206, 206, 206, 206, 293, 293, 293, 293, 293, 293, 293, 293,
        293, 293, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 135, 135,
        135, 135, 135, 135, 135, 135, 135, 135, 394, 394, 394, 394, 394, 394,
        394, 394, 394, 394, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
        366, 366, 366, 366, 366, 366, 366, 366, 366, 366,  59,  59,  59,  59,
         59,  59,  59,  59,  59,  59, 235, 235, 235, 235, 235, 235, 235, 235,
        235, 235, 352, 352, 352, 352, 352, 352, 352, 352, 352, 352, 193, 193,
        193, 193, 193, 193, 193, 193, 193, 193, 341, 341, 341, 341, 341, 341,
        341, 341, 341, 341])

    a = torch.rand([l.size(0),10])

    print(loss_func(a, l))

KevinMusgrave commented 2 years ago

Thanks @amirdnc

yankungou commented 2 years ago

Hi @KevinMusgrave, thank you for your incredible implementation. I met the same issue as @amirdnc. Has this issue been fixed yet? Thanks!

KevinMusgrave commented 2 years ago

@YK711 Sorry I haven't gotten around to this yet.

KevinMusgrave commented 2 years ago

Maybe @cwkeam has some free time to look into this :smile:

yankungou commented 2 years ago

Thanks! @KevinMusgrave

yankungou commented 2 years ago

Hi @cwkeam, I met the same issue. Do you have any idea how to fix it? Thanks!

KevinMusgrave commented 1 year ago

I've removed this loss function from v2.0.0 as I don't have time to figure out the bug. If someone figures it out, please open a pull request.

pgrosjean commented 1 year ago

@KevinMusgrave @cwkeam

If you use the MPerClassSampler with a batch size divisible by M, there is no issue with the code as is, so if people want a quick fix, I suggest that. But I also tracked down the bug in the code: the problem only arises when, within a batch, classes have different numbers of instances.
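
A sketch of that quick fix (dataset and train_labels are placeholder names; batch_size must be a multiple of m):

from torch.utils.data import DataLoader
from pytorch_metric_learning import samplers

# sample exactly m items per class so every class in a batch has the same count
m, batch_size = 4, 32
sampler = samplers.MPerClassSampler(train_labels, m=m, batch_size=batch_size)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)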

The bug is on line 176 of centroid_triplet_loss.py, where the index 0 is appended to the query_indices list when instance_idx >= len(class_insts); this handles the case where a class has fewer than the maximum number of instances in the batch. However, because you simply append 0, this indexes into the class at position 0 of your class list. So when you call get_matches_and_diffs on labels_concat, you get a matches matrix that looks like this:

[screenshot: matches matrix with spurious off-diagonal True entries]

True values that are off the diagonal correspond to the class at index 0.

If you append class_insts[0] to query_indices instead of 0, this fixes the issue, and you get a matches matrix from get_matches_and_diffs that looks like this:

[screenshot: corrected matches matrix]

I have not done rigorous testing, but this does fix the issue that @yankungou and @amirdnc were having with this loss.

The debugged version of the create_masks_train function:

def create_masks_train(self, class_labels):
    labels_dict = defaultdict(list)
    class_labels = class_labels.detach().cpu().numpy()
    for idx, pid in enumerate(class_labels):
        labels_dict[pid].append(idx)

    unique_classes = list(labels_dict.keys())
    labels_list = list(labels_dict.values())
    lens_list = [len(item) for item in labels_list]
    lens_list_cs = np.cumsum(lens_list)

    M = max(len(instances) for instances in labels_list)
    P = len(unique_classes)

    query_indices = []
    class_masks = torch.zeros((P, len(class_labels)), dtype=bool)
    masks = torch.zeros((M * P, len(class_labels)), dtype=bool)
    for class_idx, class_insts in enumerate(labels_list):
        class_masks[class_idx, class_insts] = 1
        for instance_idx in range(M):
            matrix_idx = class_idx * M + instance_idx
            if instance_idx < len(class_insts):
                query_indices.append(class_insts[instance_idx])
                ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
                masks[matrix_idx, ones] = 1
            else:
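                # the fix: pad with this class's own first index instead of the global index 0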
                query_indices.append(class_insts[0])
    return masks, class_masks, labels_list, query_indices

Code showing that the problem is solved with this update:

from collections import defaultdict

import numpy as np
import torch

from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses import BaseMetricLossFunction
from pytorch_metric_learning.losses import TripletMarginLoss

def concat_indices_tuple(x):
    return [torch.cat(y) for y in zip(*x)]

class CentroidTripletLoss(BaseMetricLossFunction):
    def __init__(
        self,
        margin=0.05,
        swap=False,
        smooth_loss=False,
        triplets_per_anchor="all",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.triplet_loss = TripletMarginLoss(
            margin=margin,
            swap=swap,
            smooth_loss=smooth_loss,
            triplets_per_anchor=triplets_per_anchor,
            **kwargs
        )

    def compute_loss(
        self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        c_f.indices_tuple_not_supported(indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        """
        "During training stage each mini-batch contains 𝑃 distinct item
        classes with 𝑀 samples per class, resulting in batch size of 𝑃 × 𝑀."
        """
        masks, class_masks, labels_list, query_indices = self.create_masks_train(labels)

        P = len(labels_list)
        M = max([len(instances) for instances in labels_list])
        DIM = embeddings.size(-1)

        """
        "...each sample from S𝑘 is used as a query 𝑞𝑘 and the rest 
        𝑀 −1 samples are used to build a prototype centroid"
        i.e. for each class k of M items, we make M pairs of (query, centroid),
        making a total of P*M total pairs.
        masks = (M*P x len(embeddings)) matrix
        labels_list[i] = indices of embeddings belonging to the ith class
        centroids_emd.shape == (M*P, DIM)
        i.e.    centroids_emb[0] == centroid vector for 0th class, where the first embedding is the query vector
                centroids_emb[1] == centroid vector for 0th class, where the second embedding is the query vector
                centroids_emb[M+1] == centroid vector for 1th class, where the first embedding is the query vector
        """

        masks_float = masks.type(embeddings.type()).to(embeddings.device)
        class_masks_float = class_masks.type(embeddings.type()).to(embeddings.device)
        inst_counts = masks_float.sum(-1)
        class_inst_counts = class_masks_float.sum(-1)

        valid_mask = inst_counts > 0
        padded = masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)
        class_padded = class_masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)

        positive_centroids_emb = padded.sum(-2) / inst_counts.masked_fill(
            inst_counts == 0, 1
        ).unsqueeze(-1)

        negative_centroids_emb = class_padded.sum(-2) / class_inst_counts.masked_fill(
            class_inst_counts == 0, 1
        ).unsqueeze(-1)

        query_indices = torch.tensor(query_indices).to(embeddings.device)
        query_embeddings = embeddings.index_select(0, query_indices)
        query_labels = labels.index_select(0, query_indices)
        assert positive_centroids_emb.size() == (M * P, DIM)
        assert negative_centroids_emb.size() == (P, DIM)
        assert query_embeddings.size() == (M * P, DIM)

        query_indices = query_indices.view((P, M)).transpose(0, 1)
        query_embeddings = query_embeddings.view((P, M, -1)).transpose(0, 1)
        query_labels = query_labels.view((P, M)).transpose(0, 1)
        positive_centroids_emb = positive_centroids_emb.view((P, M, -1)).transpose(0, 1)
        valid_mask = valid_mask.view((P, M)).transpose(0, 1)

        labels_collect = []
        embeddings_collect = []
        tuple_indices_collect = []
        starting_idx = 0
        for inst_idx in range(M):
            one_mask = valid_mask[inst_idx]
            if torch.sum(one_mask) > 1:
                anchors = query_embeddings[inst_idx][one_mask]
                pos_centroids = positive_centroids_emb[inst_idx][one_mask]
                one_labels = query_labels[inst_idx][one_mask]

                embeddings_concat = torch.cat(
                    (anchors, pos_centroids, negative_centroids_emb)
                )
                labels_concat = torch.cat(
                    (one_labels, one_labels, query_labels[inst_idx])
                )
                indices_tuple = get_all_triplets_indices(labels_concat)

                """
                Right now indices tuple considers all embeddings in
                embeddings_concat as anchors, pos_example, neg_examples.
                1. make only query vectors be anchor vectors
                2. make pos_centroids be only used as a positive example
                3. negative as so
                """
                # make only query vectors be anchor vectors
                indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]

                # make only pos_centroids be positive examples
                indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
                indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]

                # make only neg_centroids be negative examples
                indices_tuple = [
                    x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple
                ]

                tuple_indices_collect.append(indices_tuple)
                embeddings_collect.append(embeddings_concat)
                labels_collect.append(labels_concat)
                starting_idx += len(labels_concat)

        indices_tuple = concat_indices_tuple(tuple_indices_collect)

        if len(indices_tuple) == 0:
            return self.zero_losses()

        final_embeddings = torch.cat(embeddings_collect)
        final_labels = torch.cat(labels_collect)

        loss = self.triplet_loss.compute_loss(
            final_embeddings, final_labels, indices_tuple, ref_emb=None, ref_labels=None
        )
        return loss

    def create_masks_train(self, class_labels):
        labels_dict = defaultdict(list)
        class_labels = class_labels.detach().cpu().numpy()
        for idx, pid in enumerate(class_labels):
            labels_dict[pid].append(idx)

        unique_classes = list(labels_dict.keys())
        labels_list = list(labels_dict.values())
        lens_list = [len(item) for item in labels_list]
        lens_list_cs = np.cumsum(lens_list)

        M = max(len(instances) for instances in labels_list)
        P = len(unique_classes)

        query_indices = []
        class_masks = torch.zeros((P, len(class_labels)), dtype=bool)
        masks = torch.zeros((M * P, len(class_labels)), dtype=bool)
        for class_idx, class_insts in enumerate(labels_list):
            class_masks[class_idx, class_insts] = 1
            for instance_idx in range(M):
                matrix_idx = class_idx * M + instance_idx
                if instance_idx < len(class_insts):
                    query_indices.append(class_insts[instance_idx])
                    ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
                    masks[matrix_idx, ones] = 1
                else:
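                    # the fix: pad with this class's own first index instead of the global index 0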
                    query_indices.append(class_insts[0])
        return masks, class_masks, labels_list, query_indices

def get_all_triplets_indices(labels, ref_labels=None):
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    return torch.where(triplets)

def get_matches_and_diffs(labels, ref_labels=None):
    if ref_labels is None:
        ref_labels = labels
    labels1 = labels.unsqueeze(1)
    labels2 = ref_labels.unsqueeze(0)
    matches = (labels1 == labels2).byte()
    diffs = matches ^ 1
    if ref_labels is labels:
        matches.fill_diagonal_(0)
    return matches, diffs

loss_func = CentroidTripletLoss()
l = torch.tensor([228, 228, 228, 228, 228, 228, 228, 228, 228, 228,  83,  83,  83,  83,
     83,  83,  83,  83,  83,  83, 200, 200, 200, 200, 200, 200, 200, 200,
    200, 200, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 250, 250,
    250, 250, 250, 250, 250, 250, 250, 250,  13,  13,  13,  13,  13,  13,
     13,  13,  13,  13, 124, 124, 124, 124, 124, 124, 124, 124, 124, 124,
     86,  86,  86,  86,  86,  86,  86,  86,  86,  86, 179, 179, 179, 179,
    179, 179, 179, 179, 179, 179, 338, 338, 338, 338, 338, 338, 338, 338,
    338, 338, 146, 146, 146, 146, 146, 146, 146, 146, 146, 146, 225, 225,
    225, 225, 225, 225, 225, 225, 225, 225,  22,  22,  22,  22,  22,  22,
     22,  22,  22,  22,  73,  73,  73,  73,  73,  73,  73,  73,  73,  73,
    300, 300, 300, 300, 300, 300, 300, 300, 300, 300, 431, 431, 431, 431,
    431, 431, 431, 431, 431, 431, 349, 349, 349, 349, 349, 349, 349, 349,
    349, 349, 305, 305, 305, 305, 305, 305, 305, 305, 305, 305, 117, 117,
    117, 117, 117, 117, 117, 117, 117, 117,  56, 56, 206, 206, 206, 206,
    206, 206, 206, 206, 206, 206, 1, 1, 293, 293, 293, 293, 293, 293, 293, 293,
    293, 293, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 135, 135,
    135, 135, 135, 135, 135, 135, 135, 135, 394, 394, 394, 394, 394, 394,
    394, 394, 394, 394, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
    366, 366, 366, 366, 366, 366, 366, 366, 366, 366,  59,  59,  59,  59,
     59,  59,  59,  59,  59,  59, 235, 235, 235, 235, 235, 235, 235, 235,
    235, 235, 352, 352, 352, 352, 352, 352, 352, 352, 352, 352, 193, 193,
    193, 193, 193, 193, 193, 193, 193, 193, 341, 341, 341, 341, 341, 341,
    341, 341, 341, 341])

a = torch.rand([l.size(0),10])

print(loss_func(a, l))

Hopefully, you can bring back this loss in the next update. Let me know if you want me to submit a PR or anything. Thanks!

KevinMusgrave commented 1 year ago

@pgrosjean Sorry for the delayed response. I've been extra busy the past few weeks.

Thank you for investigating this and providing such a detailed explanation with plots!

Yes, please open a PR. Here's the old unittest file you can copy-paste for now: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/c0cc5de16134704d6b14ccf5ac866e134121e032/tests/losses/test_centroid_triplet_loss.py