Open fjssharpsword opened 2 years ago
Could you provide a snippet of code that produces the error?
Thank you for your attention! Our code is as below, and loss_func is:
CentroidTripletLoss(margin=0.05, swap=False, smooth_loss=False, triplets_per_anchor="all")
# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()
We have now found that this issue is unrelated to batch size. It may instead be caused by the length of the input x not being divisible by the length of the one_labels list. Hence, we are now trying to tune the length returned by lmu.get_all_triplets_indices or to set the parameter triplets_per_anchor.
As a temporary measure, we truncate indices_tuple in the compute_loss function of losses.CentroidTripletLoss, as follows:
# make only query vectors be anchor vectors
indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]
# added by Jason.Fang
remainder = len(indices_tuple[0]) % len(one_labels)  # the length may not be divisible
quotient = len(indices_tuple[0]) // len(one_labels)
indices_tuple = [x[: len(indices_tuple[0]) - remainder] for x in indices_tuple]
# make only pos_centroids be positive examples
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]
# make only neg_centroids be negative examples
indices_tuple = [x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple]
This is a temporary solution, and truncating the triplet samples may hurt the results. We look forward to a proper fix.
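To make the cost of this truncation concrete, here is a minimal sketch in plain Python (no PyTorch needed). The helper name `truncate_to_multiple` is hypothetical; the sizes (52 entries with 3 labels, 70330 entries with 6 labels) are taken from the two error messages reported in this thread.

```python
def truncate_to_multiple(n_items, n_rows):
    """Return (kept, dropped) when n_items is cut down to a multiple of n_rows.

    Mirrors the workaround above: x.view(n_rows, -1) only succeeds when
    n_items is divisible by n_rows, so the remainder is discarded.
    """
    remainder = n_items % n_rows
    return n_items - remainder, remainder


# Sizes taken from the RuntimeErrors reported in this thread.
for n_items, n_rows in [(52, 3), (70330, 6)]:
    kept, dropped = truncate_to_multiple(n_items, n_rows)
    print(f"size={n_items}, rows={n_rows}: keep {kept}, drop {dropped}")
```

Only a handful of entries are dropped per call, but the surviving entries are then reshaped as if every label contributed the same number of triplets, which is not guaranteed.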
@cwkeam Do you remember how this part works? https://github.com/KevinMusgrave/pytorch-metric-learning/blob/58247798ca9bf62ff49874e5cd07c41424e64fe9/src/pytorch_metric_learning/losses/centroid_triplet_loss.py#L124
It fails this test that I've added to the dev branch: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/3930db293cd2c71181ffb776cc71b5e8a7d1f502/tests/losses/test_centroid_triplet_loss.py#L20-L24
The error message:
Traceback (most recent call last):
File "pytorch-metric-learning/tests/losses/test_centroid_triplet_loss.py", line 24, in test_indices_tuple_failure
loss_fn(embeddings, labels)
File "python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "pytorch-metric-learning/src/pytorch_metric_learning/losses/base_metric_loss_function.py", line 37, in forward
loss_dict = self.compute_loss(
File "pytorch-metric-learning/src/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
File "pytorch-metric-learning/src/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp>
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
RuntimeError: shape '[3, -1]' is invalid for input of size 52
Command to run the test:
python -m unittest tests.losses.test_centroid_triplet_loss.TestCentroidTripletLoss.test_indices_tuple_failure
@KevinMusgrave @fjssharpsword Thanks for pointing this out. I'll take a look.
@KevinMusgrave @fjssharpsword
After reviewing the code and the old discussions about this loss function, I found that the above test fails simply because there is a class with only one embedding.
The code comments, input parameter checks, and documentation haven't been clear on this, and I've adjusted the code locally to improve that.
If you look at an old discussion post here, though it isn't exactly about this, you can see Kevin say:
labels = [0,0,0,1,1,1]
The positive centroid is the average of embeddings 1 and 2,
and the anchor is embedding 0, so that makes sense.
In the above example with [0,0,0,1,1,1], there are 6 embeddings: [0,1,2] in class A and [3,4,5] in class B. Let's compute the loss for each embedding:
# 1
anchor = embs[0]
positive_centroid = avg(embs[1], embs[2])
negative_centroid = avg(embs[3:6])
# 2
anchor = embs[1]
positive_centroid = avg(embs[0], embs[2])
negative_centroid = avg(embs[3:6])
...
The reason we leave the anchor embedding out of the positive centroid is that it's kind of perverse to measure the distance from an embedding to a centroid it's included in. It's also implemented that way in the code I referenced for this PR.
So the above process clearly breaks for classes with only one sample (i.e. avg([])).
It's not really about batch sizes: test cases with arbitrarily imbalanced embedding sets, in any order, pass successfully, and a test fails only if the embeddings include a class with just one sample.
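The leave-one-out centroid described above can be sketched in plain Python, with lists standing in for tensors. The helper names `mean_vector` and `centroids_for_anchor` and the toy embeddings are made up for illustration; this is not the library's implementation.

```python
def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    if not vectors:
        raise ValueError("cannot average an empty set of embeddings")
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def centroids_for_anchor(embs, labels, anchor_idx):
    """Return (positive_centroid, {other_label: negative_centroid}).

    The positive centroid averages the anchor's class WITHOUT the anchor.
    """
    anchor_label = labels[anchor_idx]
    positives = [
        e for i, e in enumerate(embs)
        if labels[i] == anchor_label and i != anchor_idx
    ]
    positive_centroid = mean_vector(positives)  # raises for a one-sample class
    negative_centroids = {}
    for other in sorted(set(labels) - {anchor_label}):
        members = [e for i, e in enumerate(embs) if labels[i] == other]
        negative_centroids[other] = mean_vector(members)
    return positive_centroid, negative_centroids


embs = [[0.0, 0.0], [2.0, 0.0], [4.0, 0.0], [0.0, 3.0], [0.0, 6.0], [0.0, 9.0]]
labels = [0, 0, 0, 1, 1, 1]
pos, neg = centroids_for_anchor(embs, labels, anchor_idx=0)
print(pos, neg)
```

With labels such as [0, 0, 1], calling the helper with anchor_idx=2 raises the ValueError, because the set of positives for that anchor is empty.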
@KevinMusgrave I'm leaning towards raising a ValueError for cases like this. Or do you think I should support it somehow? If it had to be supported, the loss for the lone anchor would be defined using only the negative centroid part.
Let me know what you think!
Thanks @cwkeam, I think raising a ValueError makes sense.
In v1.3.0, a ValueError will be raised if any of the labels have only 1 embedding.
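For readers who want the same guard in their own pipeline, here is a minimal sketch of such a check. The function name `check_min_samples_per_label` and the error message are hypothetical; the check that actually ships in the library may be written differently.

```python
from collections import Counter


def check_min_samples_per_label(labels, min_samples=2):
    """Raise ValueError if any label has fewer than min_samples entries."""
    counts = Counter(labels)
    too_small = sorted(label for label, n in counts.items() if n < min_samples)
    if too_small:
        raise ValueError(
            f"labels with fewer than {min_samples} embeddings: {too_small}"
        )


check_min_samples_per_label([0, 0, 0, 1, 1, 1])  # passes silently
try:
    check_min_samples_per_label([0, 0, 1])
except ValueError as err:
    print(err)
```

Running a check like this on each batch before calling the loss turns the opaque reshape error into a message that names the offending labels.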
I get a similar error even when each label has more than 1 embedding. It seems that the discussed line of code assumes that the number of labels is a divisor of len(indices_tuple), which isn't always the case.
@amirdnc Could you provide a snippet of code so that we can reproduce the error? I think the labels array would be sufficient.
Sure. Try this:
import torch
from pytorch_metric_learning import losses
loss_func = losses.CentroidTripletLoss()
l = torch.tensor([228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 83, 83, 83, 83,
83, 83, 83, 83, 83, 83, 200, 200, 200, 200, 200, 200, 200, 200,
200, 200, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 250, 250,
250, 250, 250, 250, 250, 250, 250, 250, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 124, 124, 124, 124, 124, 124, 124, 124, 124, 124,
86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 179, 179, 179, 179,
179, 179, 179, 179, 179, 179, 338, 338, 338, 338, 338, 338, 338, 338,
338, 338, 146, 146, 146, 146, 146, 146, 146, 146, 146, 146, 225, 225,
225, 225, 225, 225, 225, 225, 225, 225, 22, 22, 22, 22, 22, 22,
22, 22, 22, 22, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
300, 300, 300, 300, 300, 300, 300, 300, 300, 300, 431, 431, 431, 431,
431, 431, 431, 431, 431, 431, 349, 349, 349, 349, 349, 349, 349, 349,
349, 349, 305, 305, 305, 305, 305, 305, 305, 305, 305, 305, 117, 117,
117, 117, 117, 117, 117, 117, 117, 117, 56, 56, 206, 206, 206, 206,
206, 206, 206, 206, 206, 206, 293, 293, 293, 293, 293, 293, 293, 293,
293, 293, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 135, 135,
135, 135, 135, 135, 135, 135, 135, 135, 394, 394, 394, 394, 394, 394,
394, 394, 394, 394, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
366, 366, 366, 366, 366, 366, 366, 366, 366, 366, 59, 59, 59, 59,
59, 59, 59, 59, 59, 59, 235, 235, 235, 235, 235, 235, 235, 235,
235, 235, 352, 352, 352, 352, 352, 352, 352, 352, 352, 352, 193, 193,
193, 193, 193, 193, 193, 193, 193, 193, 341, 341, 341, 341, 341, 341,
341, 341, 341, 341])
a = torch.rand([l.size(0),10])
print(loss_func(a, l))
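Counting samples per label shows how this reproduction differs from a balanced batch: in the tensor above, every label appears 10 times except 56, which appears only twice. The sketch below uses a shortened stand-in for that tensor, and the helper name `class_sizes` is made up for illustration.

```python
from collections import Counter


def class_sizes(labels):
    """Map each class size to the list of labels that have that many samples."""
    sizes = {}
    for label, n in Counter(labels).items():
        sizes.setdefault(n, []).append(label)
    return sizes


# Shortened stand-in for the label tensor above: three full classes plus the short one.
labels = [228] * 10 + [117] * 10 + [56] * 2 + [206] * 10
sizes = class_sizes(labels)
print(sizes)
if len(sizes) > 1:
    print("batch is imbalanced: classes do not all have the same size")
```

More than one key in the result means the batch mixes classes of different sizes, which is the situation this report runs into.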
Thanks @amirdnc
Hi @KevinMusgrave, thank you for your incredible implementation. I ran into the same issue as @amirdnc. Has this issue been fixed yet? Thanks!
@YK711 Sorry I haven't gotten around to this yet.
Maybe @cwkeam has some free time to look into this :smile:
Thanks! @KevinMusgrave
Hi @cwkeam, I ran into the same issue. Do you have any idea how to fix it? Thanks!
I've removed this loss function from v2.0.0 as I don't have time to figure out the bug. If someone figures it out, please open a pull request.
@KevinMusgrave @cwkeam
If you use the MPerClassSampler with the batch size divisible by M, then there is no issue with the code as is, so if people want a quick fix, I suggest this. But I also figured out the bug in the code. The problem only arises when, within a batch, the classes have different numbers of instances.
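To illustrate what "M samples per class, batch size divisible by M" means for the batches, here is a plain-Python stand-in. It is not the library's MPerClassSampler; the function name `m_per_class_batches`, its arguments, and the toy labels are assumptions made for this sketch.

```python
import random
from collections import Counter, defaultdict


def m_per_class_batches(labels, m, batch_size, seed=0):
    """Yield index batches with exactly m samples for each of batch_size // m classes.

    Illustrative stand-in only. Classes with fewer than m samples are
    sampled with replacement so every class still contributes m entries.
    """
    if batch_size % m != 0:
        raise ValueError("batch_size must be divisible by m")
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    classes = list(by_label)
    rng.shuffle(classes)
    classes_per_batch = batch_size // m
    for start in range(0, len(classes) - classes_per_batch + 1, classes_per_batch):
        batch = []
        for label in classes[start:start + classes_per_batch]:
            pool = by_label[label]
            picks = rng.sample(pool, m) if len(pool) >= m else rng.choices(pool, k=m)
            batch.extend(picks)
        yield batch


labels = [0] * 5 + [1] * 3 + [2] * 4 + [3] * 2
batches = list(m_per_class_batches(labels, m=2, batch_size=4))
for batch in batches:
    print(Counter(labels[i] for i in batch))
```

Every batch produced this way has the same number of instances per class, so the padding branch that triggers the bug is never reached.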
The bug is in line 176 of centroid_triplet_loss.py, where the index 0 is appended to the query_indices list when instance_idx >= len(class_insts). This is meant to handle a class that has fewer than the maximum number of class instances in a batch. However, because a plain 0 is appended, it points at embedding 0 of the batch, which belongs to the class at the zeroth index in the class list. So when get_matches_and_diffs is called on labels_concat, you get a matches matrix that looks like this:
True values that are off the diagonal correspond to the class at index 0.
If you append class_insts[0] to query_indices instead of 0, this fixes the issue, and the matches matrix from get_matches_and_diffs looks like this:
I have not done rigorous testing, but this does fix the issue that @yankungo and @amirdnc were having with this loss.
def create_masks_train(self, class_labels):
    labels_dict = defaultdict(list)
    class_labels = class_labels.detach().cpu().numpy()
    for idx, pid in enumerate(class_labels):
        labels_dict[pid].append(idx)
    unique_classes = list(labels_dict.keys())
    labels_list = list(labels_dict.values())
    lens_list = [len(item) for item in labels_list]
    lens_list_cs = np.cumsum(lens_list)
    M = max(len(instances) for instances in labels_list)
    P = len(unique_classes)
    query_indices = []
    class_masks = torch.zeros((P, len(class_labels)), dtype=bool)
    masks = torch.zeros((M * P, len(class_labels)), dtype=bool)
    for class_idx, class_insts in enumerate(labels_list):
        class_masks[class_idx, class_insts] = 1
        for instance_idx in range(M):
            matrix_idx = class_idx * M + instance_idx
            if instance_idx < len(class_insts):
                query_indices.append(class_insts[instance_idx])
                ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
                masks[matrix_idx, ones] = 1
            else:
                query_indices.append(class_insts[0])
    return masks, class_masks, labels_list, query_indices
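The effect of the padding change can be seen on a tiny label list without running the loss. The sketch below reproduces only the query_indices logic in plain Python; the helper name `build_query_indices` and its flag are hypothetical, and the "old" branch follows the description given in this comment.

```python
def build_query_indices(labels, pad_with_class_member):
    """Pad every class to M query slots, following the loop in create_masks_train."""
    by_label = {}
    for idx, label in enumerate(labels):
        by_label.setdefault(label, []).append(idx)
    m = max(len(insts) for insts in by_label.values())
    query_indices = []
    for class_insts in by_label.values():
        for instance_idx in range(m):
            if instance_idx < len(class_insts):
                query_indices.append(class_insts[instance_idx])
            else:
                # Old behavior: pad with 0, an index into the whole batch.
                # Proposed fix: pad with class_insts[0], an index inside this class.
                query_indices.append(class_insts[0] if pad_with_class_member else 0)
    return query_indices


labels = [7, 7, 7, 9, 9]  # class 9 has fewer samples than class 7
old = build_query_indices(labels, pad_with_class_member=False)
new = build_query_indices(labels, pad_with_class_member=True)
print("old query labels:", [labels[i] for i in old])
print("new query labels:", [labels[i] for i in new])
```

With the old padding, the extra slot for class 9 carries the label of embedding 0 (class 7), which is where the stray matches come from. With the fix, the padded slot keeps the label of its own class.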
from collections import defaultdict
import numpy as np
import torch
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses import BaseMetricLossFunction
from pytorch_metric_learning.losses import TripletMarginLoss
def concat_indices_tuple(x):
    return [torch.cat(y) for y in zip(*x)]
class CentroidTripletLoss(BaseMetricLossFunction):
    def __init__(
        self,
        margin=0.05,
        swap=False,
        smooth_loss=False,
        triplets_per_anchor="all",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.triplet_loss = TripletMarginLoss(
            margin=margin,
            swap=swap,
            smooth_loss=smooth_loss,
            triplets_per_anchor=triplets_per_anchor,
            **kwargs
        )
    def compute_loss(
        self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        c_f.indices_tuple_not_supported(indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        """
        "During training stage each mini-batch contains 𝑃 distinct item
        classes with 𝑀 samples per class, resulting in batch size of 𝑃 × 𝑀."
        """
        masks, class_masks, labels_list, query_indices = self.create_masks_train(labels)
        P = len(labels_list)
        M = max([len(instances) for instances in labels_list])
        DIM = embeddings.size(-1)
        """
        "...each sample from S𝑘 is used as a query 𝑞𝑘 and the rest
        𝑀 −1 samples are used to build a prototype centroid"
        i.e. for each class k of M items, we make M pairs of (query, centroid),
        making a total of P*M total pairs.
        masks = (M*P x len(embeddings)) matrix
        labels_list[i] = indices of embeddings belonging to ith class
        centroids_emb.shape == (M*P, DIM)
        i.e. centroids_emb[0] == centroid vector for 0th class, where the first embedding is the query vector
             centroids_emb[1] == centroid vector for 0th class, where the second embedding is the query vector
             centroids_emb[M] == centroid vector for 1st class, where the first embedding is the query vector
        """
        masks_float = masks.type(embeddings.type()).to(embeddings.device)
        class_masks_float = class_masks.type(embeddings.type()).to(embeddings.device)
        inst_counts = masks_float.sum(-1)
        class_inst_counts = class_masks_float.sum(-1)
        valid_mask = inst_counts > 0
        padded = masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)
        class_padded = class_masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)
        positive_centroids_emb = padded.sum(-2) / inst_counts.masked_fill(
            inst_counts == 0, 1
        ).unsqueeze(-1)
        negative_centroids_emb = class_padded.sum(-2) / class_inst_counts.masked_fill(
            class_inst_counts == 0, 1
        ).unsqueeze(-1)
        query_indices = torch.tensor(query_indices).to(embeddings.device)
        query_embeddings = embeddings.index_select(0, query_indices)
        query_labels = labels.index_select(0, query_indices)
        assert positive_centroids_emb.size() == (M * P, DIM)
        assert negative_centroids_emb.size() == (P, DIM)
        assert query_embeddings.size() == (M * P, DIM)
        query_indices = query_indices.view((P, M)).transpose(0, 1)
        query_embeddings = query_embeddings.view((P, M, -1)).transpose(0, 1)
        query_labels = query_labels.view((P, M)).transpose(0, 1)
        positive_centroids_emb = positive_centroids_emb.view((P, M, -1)).transpose(0, 1)
        valid_mask = valid_mask.view((P, M)).transpose(0, 1)
        labels_collect = []
        embeddings_collect = []
        tuple_indices_collect = []
        starting_idx = 0
        for inst_idx in range(M):
            one_mask = valid_mask[inst_idx]
            if torch.sum(one_mask) > 1:
                anchors = query_embeddings[inst_idx][one_mask]
                pos_centroids = positive_centroids_emb[inst_idx][one_mask]
                one_labels = query_labels[inst_idx][one_mask]
                embeddings_concat = torch.cat(
                    (anchors, pos_centroids, negative_centroids_emb)
                )
                labels_concat = torch.cat(
                    (one_labels, one_labels, query_labels[inst_idx])
                )
                indices_tuple = get_all_triplets_indices(labels_concat)
                """
                Right now indices tuple considers all embeddings in
                embeddings_concat as anchors, pos_example, neg_examples.
                1. make only query vectors be anchor vectors
                2. make pos_centroids be only used as a positive example
                3. negative as so
                """
                # make only query vectors be anchor vectors
                indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]
                # make only pos_centroids be positive examples
                indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
                indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]
                # make only neg_centroids be negative examples
                indices_tuple = [
                    x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple
                ]
                tuple_indices_collect.append(indices_tuple)
                embeddings_collect.append(embeddings_concat)
                labels_collect.append(labels_concat)
                starting_idx += len(labels_concat)
        indices_tuple = concat_indices_tuple(tuple_indices_collect)
        if len(indices_tuple) == 0:
            return self.zero_losses()
        final_embeddings = torch.cat(embeddings_collect)
        final_labels = torch.cat(labels_collect)
        loss = self.triplet_loss.compute_loss(
            final_embeddings, final_labels, indices_tuple, ref_emb=None, ref_labels=None
        )
        return loss
    def create_masks_train(self, class_labels):
        labels_dict = defaultdict(list)
        class_labels = class_labels.detach().cpu().numpy()
        for idx, pid in enumerate(class_labels):
            labels_dict[pid].append(idx)
        unique_classes = list(labels_dict.keys())
        labels_list = list(labels_dict.values())
        lens_list = [len(item) for item in labels_list]
        lens_list_cs = np.cumsum(lens_list)
        M = max(len(instances) for instances in labels_list)
        P = len(unique_classes)
        query_indices = []
        class_masks = torch.zeros((P, len(class_labels)), dtype=bool)
        masks = torch.zeros((M * P, len(class_labels)), dtype=bool)
        for class_idx, class_insts in enumerate(labels_list):
            class_masks[class_idx, class_insts] = 1
            for instance_idx in range(M):
                matrix_idx = class_idx * M + instance_idx
                if instance_idx < len(class_insts):
                    query_indices.append(class_insts[instance_idx])
                    ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
                    masks[matrix_idx, ones] = 1
                else:
                    query_indices.append(class_insts[0])
        return masks, class_masks, labels_list, query_indices
def get_all_triplets_indices(labels, ref_labels=None):
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    return torch.where(triplets)
def get_matches_and_diffs(labels, ref_labels=None):
    if ref_labels is None:
        ref_labels = labels
    labels1 = labels.unsqueeze(1)
    labels2 = ref_labels.unsqueeze(0)
    matches = (labels1 == labels2).byte()
    diffs = matches ^ 1
    if ref_labels is labels:
        matches.fill_diagonal_(0)
    return matches, diffs
loss_func = CentroidTripletLoss()
l = torch.tensor([228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 83, 83, 83, 83,
83, 83, 83, 83, 83, 83, 200, 200, 200, 200, 200, 200, 200, 200,
200, 200, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 250, 250,
250, 250, 250, 250, 250, 250, 250, 250, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 124, 124, 124, 124, 124, 124, 124, 124, 124, 124,
86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 179, 179, 179, 179,
179, 179, 179, 179, 179, 179, 338, 338, 338, 338, 338, 338, 338, 338,
338, 338, 146, 146, 146, 146, 146, 146, 146, 146, 146, 146, 225, 225,
225, 225, 225, 225, 225, 225, 225, 225, 22, 22, 22, 22, 22, 22,
22, 22, 22, 22, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
300, 300, 300, 300, 300, 300, 300, 300, 300, 300, 431, 431, 431, 431,
431, 431, 431, 431, 431, 431, 349, 349, 349, 349, 349, 349, 349, 349,
349, 349, 305, 305, 305, 305, 305, 305, 305, 305, 305, 305, 117, 117,
117, 117, 117, 117, 117, 117, 117, 117, 56, 56, 206, 206, 206, 206,
206, 206, 206, 206, 206, 206, 1, 1, 293, 293, 293, 293, 293, 293, 293, 293,
293, 293, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 135, 135,
135, 135, 135, 135, 135, 135, 135, 135, 394, 394, 394, 394, 394, 394,
394, 394, 394, 394, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
366, 366, 366, 366, 366, 366, 366, 366, 366, 366, 59, 59, 59, 59,
59, 59, 59, 59, 59, 59, 235, 235, 235, 235, 235, 235, 235, 235,
235, 235, 352, 352, 352, 352, 352, 352, 352, 352, 352, 352, 193, 193,
193, 193, 193, 193, 193, 193, 193, 193, 341, 341, 341, 341, 341, 341,
341, 341, 341, 341])
a = torch.rand([l.size(0),10])
print(loss_func(a, l))
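For readers unfamiliar with the two helper functions at the end of the script above, here is a plain-Python sketch of what they enumerate: every (anchor, positive, negative) index triplet where the anchor and positive share a label, the anchor is not its own positive, and the negative has a different label. The name `all_triplets` is hypothetical, and the loops stand in for the tensor broadcasting used in the real code.

```python
def all_triplets(labels):
    """List every (anchor, positive, negative) index triplet for a label list."""
    n = len(labels)
    return [
        (a, p, neg)
        for a in range(n)
        for p in range(n)
        for neg in range(n)
        if a != p and labels[a] == labels[p] and labels[a] != labels[neg]
    ]


small = all_triplets([0, 0, 1])
print(small)

balanced = all_triplets([0, 0, 0, 1, 1, 1])
print(len(balanced))
```

In the balanced case each anchor has 2 positives and 3 negatives, so the triplet count is the same for every anchor. The reshape inside compute_loss relies on that kind of regularity, which is why mislabeled padded queries break it.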
Hopefully, you can bring back this loss in the next update. Let me know if you want me to submit a PR or anything. Thanks!
@pgrosjean Sorry for the delayed response. I've been extra busy the past few weeks.
Thank you for investigating this and providing such a detailed explanation with plots!
Yes, please open a PR. Here's the old unittest file you can copy-paste for now: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/c0cc5de16134704d6b14ccf5ac866e134121e032/tests/losses/test_centroid_triplet_loss.py
In version 1.2.0, Centroid Triplet Loss is unstable. When the batch size is small, the error below always occurs. With a larger batch size, it occurs less often.
File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 38, in forward
embeddings, labels, indices_tuple, ref_emb, ref_labels
File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp>
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
RuntimeError: shape '[6, -1]' is invalid for input of size 70330