CSerxy closed this issue 4 years ago.
@CSerxy What about wrapping NTXentLoss with CrossBatchMemory? This will maintain a queue of negative examples across batches, decoupling the number of negatives from the mini-batch size.
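The queue idea behind CrossBatchMemory can be sketched in plain PyTorch as a fixed-size FIFO buffer of detached embeddings and their labels. This is a minimal illustration, not the library's implementation. The class `FeatureQueue` and its methods are hypothetical, and the sketch keeps everything on the CPU for simplicity.

```python
import torch


class FeatureQueue:
    """Hypothetical fixed-size FIFO queue of detached embeddings and labels (CPU only)."""

    def __init__(self, embedding_size: int, memory_size: int):
        self.memory_size = memory_size
        self.embeddings = torch.zeros(0, embedding_size)
        self.labels = torch.zeros(0, dtype=torch.long)

    def enqueue(self, embeddings: torch.Tensor, labels: torch.Tensor) -> None:
        # Detach so that no gradients flow into the stored entries.
        self.embeddings = torch.cat([self.embeddings, embeddings.detach()], dim=0)
        self.labels = torch.cat([self.labels, labels], dim=0)
        # Keep only the newest memory_size entries (oldest are dropped first).
        self.embeddings = self.embeddings[-self.memory_size:]
        self.labels = self.labels[-self.memory_size:]

    def __len__(self) -> int:
        return self.embeddings.size(0)


queue = FeatureQueue(embedding_size=4, memory_size=8)
for step in range(3):
    batch = torch.randn(4, 4, requires_grad=True)
    queue.enqueue(batch, torch.arange(4) + 4 * step)
```

After three batches of 4, the queue holds the 8 newest entries (labels 4 to 11), so the number of negatives is set by `memory_size` rather than by the batch size.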
Just some unsolicited advice: there's good reason to believe that increasing the number of negatives can actually hurt performance (see this paper). This is in fact what I found in my own project (contrastive learning for text embeddings). Of course, you will have to try it for yourself to see!
Hi John, thanks for your suggestion! I will have a look at this function. However, I still hope this package can have an official implementation. :)
As for increasing the number of negatives in training, that is a little different from what I learned from Hinton's paper. As far as I know, the larger the batch size, the better the performance. Perhaps the result differs in NLP.
If I understand correctly, you want to form negative pairs using the current batch and negative_samples, but you don't want to compute gradients for negative_samples, due to memory constraints.
CrossBatchMemory is definitely related. It keeps a queue of previous batches automatically, and the gradients are not computed for the queue. Note that if the queue contains positive samples, those will be used to form positive pairs. So it won't just be using negative samples.
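To make the point about positives in the queue concrete: pairs between the batch and the queue are decided purely by labels, so any queue entry that shares a label with a batch element becomes a positive. A minimal plain-PyTorch sketch (the function `batch_queue_masks` is hypothetical, not part of the library):

```python
import torch


def batch_queue_masks(labels: torch.Tensor, queue_labels: torch.Tensor):
    """Return (pos_mask, neg_mask), each of shape (len(labels), len(queue_labels))."""
    same = labels.unsqueeze(1) == queue_labels.unsqueeze(0)
    return same, ~same


labels = torch.tensor([0, 1, 0, 1])      # current batch: two positive pairs
queue_labels = torch.tensor([1, 5, 6])   # queue: entry 0 shares label 1 with the batch
pos_mask, neg_mask = batch_queue_masks(labels, queue_labels)
```

Here batch elements 1 and 3 both form a positive pair with queue entry 0, and the remaining 10 batch/queue combinations are negatives.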
If you don't want to use previous batches, but rather a specific list of negative_samples, you could try this (I haven't tested it):
First create this wrapper class:
import torch
from pytorch_metric_learning.miners import BaseTupleMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu


class MinerWrapper(BaseTupleMiner):
    # return_type should be "pos" or "neg" or "all"
    def __init__(self, return_type, miner=None, **kwargs):
        super().__init__(**kwargs)
        assert return_type in ["pos", "neg", "all"]
        self.return_type = return_type
        self.miner = miner

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        if self.miner:
            indices_tuple = self.miner(embeddings, labels, ref_emb, ref_labels)
        else:
            indices_tuple = lmu.get_all_pairs_indices(labels, ref_labels)
        a1, p, a2, n = indices_tuple
        if self.return_type == "pos":
            # discard the negative pairs
            a2 = torch.LongTensor([]).to(labels.device)
            n = a2.clone()
        if self.return_type == "neg":
            # discard the positive pairs
            a1 = torch.LongTensor([]).to(labels.device)
            p = a1.clone()
        return a1, p, a2, n
Then there are two approaches. The first stacks negative_samples onto the batch and mines the pairs manually:
import torch
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses import NTXentLoss

# you can optionally pass in a miner
# like MinerWrapper(return_type="neg", miner=MultiSimilarityMiner(epsilon=0.1))
neg_miner = MinerWrapper(return_type="neg")
regular_miner = MinerWrapper(return_type="all")
loss = NTXentLoss(temperature=0.10)

#### in training loop ####
# mine negatives from embeddings and negative_samples
ns_indices_tuple = neg_miner(embeddings, labels, negative_samples, negative_sample_labels)
# a shift is required because we're going to stack the embeddings
ns_indices_tuple = c_f.shift_indices_tuple(ns_indices_tuple, len(embeddings))
# mine pairs from embeddings
regular_indices_tuple = regular_miner(embeddings, labels)
# combine regular_indices_tuple and ns_indices_tuple
indices_tuple = []
for i in range(4):
    x = torch.cat([regular_indices_tuple[i], ns_indices_tuple[i]], dim=0)
    indices_tuple.append(x)
indices_tuple = tuple(indices_tuple)
all_embeddings = torch.cat([embeddings, negative_samples.detach()], dim=0)
all_labels = torch.cat([labels, negative_sample_labels], dim=0)
loss_value = loss(all_embeddings, all_labels, indices_tuple)
loss_value.backward()
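Why the shift matters: the negative miner returns `n` as indices into `negative_samples`, but after stacking, those rows start at offset `len(embeddings)` inside `all_embeddings`. The plain-PyTorch sketch below reproduces that bookkeeping with a hypothetical helper, `shift_ref_indices`, which adds the offset to the reference-side indices `p` and `n`. This reflects my understanding of what `c_f.shift_indices_tuple` does for 4-tuples; it is not the library's code.

```python
import torch


def shift_ref_indices(indices_tuple, offset):
    """Add `offset` to the reference-side indices (p and n) of an (a1, p, a2, n) tuple."""
    a1, p, a2, n = indices_tuple
    return a1, p + offset, a2, n + offset


batch_size = 4                         # rows 0..3 of the stacked tensor
empty = torch.zeros(0, dtype=torch.long)
# negatives mined between the batch and a pool of 3 negative samples
a2 = torch.tensor([0, 0, 3])
n = torch.tensor([0, 2, 1])            # indices into negative_samples
ns_tuple = shift_ref_indices((empty, empty, a2, n), batch_size)

# the regular tuple indexes only the batch, so it needs no shift
regular = (torch.tensor([0, 2]), torch.tensor([2, 0]), torch.tensor([0]), torch.tensor([1]))
combined = tuple(torch.cat([r, s]) for r, s in zip(regular, ns_tuple))
```

After the shift, `n` becomes `[4, 6, 5]`, which are exactly the rows where those negative samples sit once they are concatenated after the 4 batch embeddings.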
The second wraps NTXentLoss in CrossBatchMemory, using MinerWrapper so that only negative pairs are formed between the batch and the queue:
from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory

batch_miner = MinerWrapper(return_type="all")
queue_miner = MinerWrapper(return_type="neg")
inner_loss = NTXentLoss(temperature=0.10)
loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=1024, miner=queue_miner)

#### in training loop ####
indices_tuple = batch_miner(embeddings, labels)
loss_value = loss(embeddings, labels, indices_tuple)
loss_value.backward()
Hi Kevin, thank you so much for such a detailed answer! Yes, you are right. What I want to do is keep the framework of Hinton's paper (i.e., one positive augmentation pair, with all other augmentations in the batch regarded as negatives) and provide the mini-batch with more negatives.
However, there are three things I want to confirm.
The embeddings and labels used in your code are the representations and labels of the mini-batch. Assuming we have a mini-batch with batch size 16 and embedding size 512, embeddings is a tensor of size 32*512 and labels = [16, 17, ..., 31, 0, 1, ..., 15].
If I have a negative sample pool of size 224 * 512, is it true that the negative_sample_labels used in your code should be something like negative_sample_labels = [0, 1, 2, 3, ..., 224]? (considering you have used c_f.shift_indices_tuple to stack the embeddings)
In my case, I just want to provide the batches with more negatives without computing gradients for them. All embeddings from outside the current batch are regarded as negatives. It seems both approaches you provided fit my situation; is that correct?
Many thanks and looking forward to your reply!
See my responses below.
> Assuming we have a mini-batch with batch size 16 and embedding size 512, then the embeddings is a tensor with size 32*512 and labels = [16, 17, ..., 31, 0, 1, ..., 15].
It looks like you're setting labels to range(32). That means there will be no positive pairs in the batch. The labels tensor needs to be constructed such that positive pairs have the same label. That's the only requirement. Maybe you meant to do labels = [0, 1, ... 15, 0, 1, ... 15].
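A quick way to check a label tensor is to count the positive pairs it implies. This plain-PyTorch sketch (the helper `count_positive_pairs` is hypothetical) counts ordered (anchor, positive) pairs, excluding each element compared with itself:

```python
import torch


def count_positive_pairs(labels: torch.Tensor) -> int:
    """Number of ordered (anchor, positive) pairs, excluding self-comparisons."""
    same = labels.unsqueeze(1) == labels.unsqueeze(0)
    not_self = ~torch.eye(labels.size(0), dtype=torch.bool)
    return int((same & not_self).sum())


batch_size = 16
wrong = torch.arange(2 * batch_size)          # range(32): every label is unique
right = torch.arange(batch_size).repeat(2)    # [0, 1, ..., 15, 0, 1, ..., 15]
```

`wrong` yields 0 positive pairs, while `right` yields 32 (each of the 32 embeddings has exactly one partner).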
> If I have a negative sample pool with size 224 * 512, is it true that the negative_sample_labels used in your code should be something like negative_sample_labels = [0, 1, 2, 3, ..., 224]? (considering you have used c_f.shift_indices_tuple to stack the embeddings)
You can set negative_sample_labels to something like range(16, 224+16). That way it will have no overlap with labels, so each element of negative_samples will be considered a negative with respect to embeddings.
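For the sizes in the question (16 pairs, a pool of 224 negatives), the suggested labels can be built and checked for overlap like this. It is a minimal sketch using the pair labels [0, ..., 15, 0, ..., 15]:

```python
import torch

batch_size, pool_size = 16, 224
labels = torch.arange(batch_size).repeat(2)                                 # 0..15, each twice
negative_sample_labels = torch.arange(batch_size, pool_size + batch_size)  # 16..239

overlap = set(labels.tolist()) & set(negative_sample_labels.tolist())
```

There are 224 pool labels and no label is shared with the batch, so no pool element can form a positive pair.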
> To my understanding, the two approaches you provided don't conflict with each other, right?
Yes, both approaches should work.
> In CrossBatchMemory, how are the previous batches' labels given? Are they automatically assigned as negatives and updated?
CrossBatchMemory will store the labels that are given to it. In order to ensure that all elements in the queue are considered negatives, you need to construct your labels at each iteration so there is no overlap with the labels in the queue.
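One way to guarantee there is never overlap is to keep a label counter that only ever increases, across iterations and across epochs. An offset derived from the loop index restarts at every epoch and assumes a constant batch size, so a running counter is a bit more robust. This is a sketch; `PairLabeler` is a hypothetical helper, not part of the library:

```python
import torch


class PairLabeler:
    """Hypothetical helper: issues pair labels that never repeat across calls."""

    def __init__(self):
        self.next_label = 0

    def __call__(self, num_pos_pairs: int, device=None) -> torch.Tensor:
        ids = torch.arange(self.next_label, self.next_label + num_pos_pairs, device=device)
        self.next_label += num_pos_pairs
        # the two views of the same sample share a label
        return torch.cat((ids, ids))


labeler = PairLabeler()
first = labeler(4)    # labels for a batch of 4 pairs
second = labeler(3)   # e.g. a smaller final batch
```

`first` is `[0, 1, 2, 3, 0, 1, 2, 3]` and `second` is `[4, 5, 6, 4, 5, 6]`: positives share labels within a batch, and no label is ever reused, so everything already in the queue stays a negative.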
I think this code snippet should do everything you want (you can ignore my previous comment about creating MinerWrapper etc.):
import torch
from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory

memory_size = 1024
inner_loss = NTXentLoss(temperature=0.10)
loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=memory_size)

#### training loop ####
# model and dataloader are assumed to be defined elsewhere;
# embeddings holds view 1 of every sample followed by view 2
for i, data in enumerate(dataloader):
    embeddings = model(data)
    num_pos_pairs = embeddings.size(0) // 2
    # create labels that indicate what the positive pairs are
    labels = torch.arange(0, num_pos_pairs, device=embeddings.device)
    labels = torch.cat((labels, labels))
    # add an offset so that the labels do not overlap with any labels in the memory queue
    labels += i * num_pos_pairs
    # compute loss
    # positive pairs are from embeddings
    # negative pairs are from embeddings + queue
    loss_value = loss(embeddings, labels)
    loss_value.backward()
If you construct your labels this way, then all elements in the CrossBatchMemory queue will always be considered negatives for the current batch.
Edit: Actually, this isn't exactly like MoCo, because the queue will contain old queries + keys, rather than just keys. I could add an optional input argument to CrossBatchMemory, something like embeddings_for_queue, so that you can specify which embeddings go into the queue. I should also add options for selecting only negatives from the queue, in case positives are present.
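For comparison, here is a minimal plain-PyTorch sketch of a MoCo-style InfoNCE loss, where each query has one positive key and the negatives come from a queue that is updated with keys only. The function name `moco_style_loss` is hypothetical, the sketch omits MoCo's momentum encoder, and it is not how CrossBatchMemory works internally.

```python
import torch
import torch.nn.functional as F


def moco_style_loss(queries, keys, queue, temperature=0.1):
    """InfoNCE with one positive key per query and negatives from a queue of past keys.

    queries, keys: (B, D) embeddings of two views; queue: (K, D) past keys (no gradient).
    """
    q = F.normalize(queries, dim=1)
    k = F.normalize(keys, dim=1)
    neg = F.normalize(queue, dim=1).detach()
    l_pos = (q * k).sum(dim=1, keepdim=True)   # (B, 1) query-vs-own-key similarity
    l_neg = q @ neg.t()                        # (B, K) query-vs-queue similarities
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # the positive key is always at column 0
    targets = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
B, D, K = 8, 16, 32
queries = torch.randn(B, D, requires_grad=True)
keys = torch.randn(B, D)
queue = torch.randn(K, D)
loss = moco_style_loss(queries, keys, queue)
loss.backward()

# enqueue the keys only (not the queries), dropping the oldest entries
memory_size = 32
queue = torch.cat([queue, keys.detach()], dim=0)[-memory_size:]
```

The last line is the difference being discussed: only `keys` enter the queue, whereas a queue filled with every embedding of the batch would also hold old queries.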
Hi Kevin,
Thank you so much!! Your code is exactly what I need. It could work!
And I just did a sanity check to compare the loss before and after adding CrossBatchMemory. I set loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=batch_size), so this loss should be the same as the NTXentLoss we used, right? Ideally, the results should be identical.
However, I found that the two losses are different: NTXentLoss = 0.17 and the CrossBatchMemory loss = 0.11. Is this normal?
That's a good sanity check! I think I know the problem. CrossBatchMemory is including self-comparisons as positive pairs (i.e. an embedding compared to itself is counting as a positive pair). I'll try to fix this tomorrow and I'll add your sanity check to the unit tests.
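To illustrate the bug being described: when the memory contains the current batch, a label-based positive mask also matches every embedding with its own copy in memory. The sketch below (the function `positive_mask` is hypothetical and is not the actual fix shipped in the library) removes those self-comparisons given the row where the batch starts inside the memory:

```python
import torch


def positive_mask(labels, ref_labels, self_offset=None):
    """Boolean (N, M) mask of positive pairs between a batch and a reference set.

    If the batch itself sits inside the reference set starting at row `self_offset`,
    the entries (i, self_offset + i) are self-comparisons and are removed.
    """
    mask = labels.unsqueeze(1) == ref_labels.unsqueeze(0)
    if self_offset is not None:
        rows = torch.arange(labels.size(0), device=labels.device)
        mask[rows, rows + self_offset] = False
    return mask


labels = torch.tensor([0, 1, 0, 1])
memory_labels = labels.clone()   # memory holds exactly the current batch
with_self = positive_mask(labels, memory_labels)
without_self = positive_mask(labels, memory_labels, self_offset=0)
```

`with_self` marks 8 positives (4 of them self-pairs, which are trivially easy and pull the loss down), while `without_self` marks the 4 genuine positives that plain NT-Xent uses.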
Excellent! Thanks!
It should be fixed now. Download v0.9.89.dev2:
pip install pytorch-metric-learning==0.9.89.dev2
Let me know if it works for you
Hi Kevin,
I really appreciate your hard work! Yes, the current version fixes the initial loss mismatch. However, when I ran both versions, I found that the loss with the CrossBatchMemory wrapper gets stuck at some point and stops decreasing.
I ran the same sanity-check experiment as in my last comment. Both losses start at 0.17 and plateau at 0.155 for a while. NTXentLoss starts decreasing again at epoch 3 (it stalls for roughly 2 epochs, which is normal), but the CrossBatchMemory loss stays stuck. After many epochs, NTXentLoss gets down to 0.001.
Do you know why this could happen? Thanks a lot!
Let (a1,p) represent the positive pairs and (a2,n) represent the negative pairs.
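For reference, an indices tuple `(a1, p, a2, n)` can be built from labels as follows: `(a1[i], p[i])` runs over all positive pairs and `(a2[j], n[j])` over all negative pairs. This plain-PyTorch function, `all_pairs_indices`, is my own sketch meant to behave like `lmu.get_all_pairs_indices`; I have not checked it against the library's code.

```python
import torch


def all_pairs_indices(labels, ref_labels=None):
    """Return (a1, p, a2, n) for every positive pair (a1, p) and negative pair (a2, n).

    When ref_labels is None, the batch is compared with itself and self-pairs are skipped.
    """
    same_set = ref_labels is None
    if same_set:
        ref_labels = labels
    matches = labels.unsqueeze(1) == ref_labels.unsqueeze(0)
    diffs = ~matches
    if same_set:
        # an element is not its own positive
        matches &= ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    a1, p = torch.where(matches)
    a2, n = torch.where(diffs)
    return a1, p, a2, n


a1, p, a2, n = all_pairs_indices(torch.tensor([0, 0, 1]))
```

For labels `[0, 0, 1]` this gives positive pairs (0, 1) and (1, 0), and negative pairs (0, 2), (1, 2), (2, 0), (2, 1).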
Closing for now, please re-open if you have more questions.
Thanks, Kevin! I think the problem I mentioned is caused by some other bugs, not related to your package. I really appreciate your help!
Hi Kevin,
I wonder if such an extended NT-Xent loss could be implemented?
The NT-Xent loss implemented in this package returns the pairwise loss when given a mini-batch and a label array. To increase the number of negative samples and make the task harder, could we use this package directly?
To be more specific, I use @JohnGiorgi's example:
import torch
from pytorch_metric_learning.losses import NTXentLoss
batch_size = 16
embedding_dim = 512
anchor_embeddings = torch.randn(batch_size, embedding_dim)
positive_embeddings = torch.randn(batch_size, embedding_dim)
embeddings = torch.cat((anchor_embeddings, positive_embeddings))
indices = torch.arange(0, anchor_embeddings.size(0), device=anchor_embeddings.device)
labels = torch.cat((indices, indices))
loss = NTXentLoss(temperature=0.10)
loss(embeddings, labels)
Assuming I have another set of negative samples of size 224 * 512, how could I use the package? I would really appreciate it if you could provide this functionality, since it would be very useful for making contrastive learning harder when resources are limited.
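To make the request concrete, here is a plain-PyTorch sketch of NT-Xent with an extra pool of detached negatives, using the same shapes as the example above. Each positive pair contributes -log(exp(s_pos) / (exp(s_pos) + sum of exp(s_neg))), where the negatives are the in-batch elements with a different label plus every row of the pool, and the result is averaged over positive pairs. The function `ntxent_with_extra_negatives` is hypothetical and is not pytorch-metric-learning's API. I have not verified that its values match `NTXentLoss` exactly.

```python
import torch
import torch.nn.functional as F


def ntxent_with_extra_negatives(embeddings, labels, extra_negatives, temperature=0.1):
    """NT-Xent over a batch plus a pool of extra negatives that receive no gradient.

    embeddings: (N, D), positive pairs share a label; extra_negatives: (K, D).
    """
    z = F.normalize(embeddings, dim=1)
    extra = F.normalize(extra_negatives.detach(), dim=1)
    sim = z @ z.t() / temperature             # (N, N) batch-vs-batch
    sim_extra = z @ extra.t() / temperature   # (N, K) batch-vs-pool

    same = labels.unsqueeze(1) == labels.unsqueeze(0)
    self_mask = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~self_mask
    neg_mask = ~same

    # log-sum-exp over each anchor's negatives: in-batch negatives + the whole pool
    neg_sim = sim.masked_fill(~neg_mask, float("-inf"))
    neg_lse = torch.logsumexp(torch.cat([neg_sim, sim_extra], dim=1), dim=1)   # (N,)

    anchors, positives = torch.where(pos_mask)
    pos_sim = sim[anchors, positives]
    # -log(exp(pos) / (exp(pos) + sum exp(neg))), averaged over positive pairs
    log_denominator = torch.logaddexp(pos_sim, neg_lse[anchors])
    return (log_denominator - pos_sim).mean()


batch_size, embedding_dim = 16, 512
anchor = torch.randn(batch_size, embedding_dim, requires_grad=True)
positive = torch.randn(batch_size, embedding_dim, requires_grad=True)
embeddings = torch.cat((anchor, positive))
indices = torch.arange(batch_size)
labels = torch.cat((indices, indices))
negative_samples = torch.randn(224, embedding_dim)   # stand-in for precomputed embeddings

loss = ntxent_with_extra_negatives(embeddings, labels, negative_samples)
loss.backward()
```

Because the pool is detached, gradients flow only into `anchor` and `positive`; the extra negatives just enlarge each denominator, which is why the loss gets larger (the task gets harder) as the pool grows.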