KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

How to use a queue of negative samples as done in MoCo #138

Closed CSerxy closed 4 years ago

CSerxy commented 4 years ago

Hi Kevin,

I wonder if such an extended NT-Xent loss could be implemented?

The NT-Xent loss implemented in this package returns the pairwise loss when given a mini-batch and a label array. Could the package be used directly to increase the number of negative samples and so make the task harder?

To be more specific, I use @JohnGiorgi's example:

import torch
from pytorch_metric_learning.losses import NTXentLoss

batch_size = 16
embedding_dim = 512

anchor_embeddings = torch.randn(batch_size, embedding_dim)
positive_embeddings = torch.randn(batch_size, embedding_dim)

embeddings = torch.cat((anchor_embeddings, positive_embeddings))
indices = torch.arange(0, anchor_embeddings.size(0), device=anchor_embeddings.device)
labels = torch.cat((indices, indices))

loss = NTXentLoss(temperature=0.10)
loss(embeddings, labels)

Assuming I have another list of negative samples with size 224 * 512, how could I use the package? I would really appreciate it if you could provide this functionality, since it would be very useful for making contrastive learning harder when resources are limited.

JohnGiorgi commented 4 years ago

@CSerxy What about wrapping NTXentLoss with CrossBatchMemory? This will maintain a queue of negative examples across batches, decoupling the number of negatives from the mini-batch size.

Just some unsolicited advice: there's good reason to believe that increasing the number of negatives can actually hurt performance (see this paper). This was in fact what I found in my own project (contrastive learning for text embeddings). Of course, you will have to try it for yourself to see!

CSerxy commented 4 years ago

Hi John, thanks for your suggestion! I will have a look at this function. However, I still hope this package can have an official implementation. :)

As for increasing the number of negatives in training, this is a little different from what I learned from Hinton's paper. As far as I know, the larger the batch size, the better the performance. Perhaps the result differs in NLP.

KevinMusgrave commented 4 years ago

If I understand correctly, you want to form negative pairs using the current batch and negative_samples, but you don't want to compute gradients for negative_samples, due to memory constraints.

CrossBatchMemory is definitely related. It keeps a queue of previous batches automatically, and the gradients are not computed for the queue. Note that if the queue contains positive samples, those will be used to form positive pairs. So it won't just be using negative samples.
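The two properties described above can be illustrated in plain PyTorch. This is only a sketch of the idea, not the library's implementation: the queue contents and labels are made-up values, and the summed similarity stands in for a real loss.

```python
import torch

torch.manual_seed(0)

# Current batch: 4 embeddings that require gradients.
embeddings = torch.randn(4, 8, requires_grad=True)
labels = torch.tensor([0, 1, 0, 1])

# Hypothetical queue contents, stored without any gradient history.
queue_embeddings = torch.randn(6, 8).detach()
queue_labels = torch.tensor([0, 5, 6, 1, 7, 8])

# Similarity between every batch element and every queue element.
similarity = embeddings @ queue_embeddings.t()

# A (batch, queue) pair is positive when the labels match, negative otherwise.
is_positive = labels.unsqueeze(1) == queue_labels.unsqueeze(0)
is_negative = ~is_positive

# Stand-in for a loss: backpropagate through the similarities.
similarity.sum().backward()

print(is_positive.sum().item())         # batch-to-queue positive pairs
print(embeddings.grad is not None)      # gradients reach the current batch
print(queue_embeddings.requires_grad)   # the queue itself is not trained
```

Because labels 0 and 1 also occur in the queue, some batch-to-queue pairs count as positives, which is the situation the note above warns about.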

If you don't want to use previous batches, but a specific list of negative_samples, you could try this (I haven't tested it):

First create this wrapper class:

import torch
from pytorch_metric_learning.miners import BaseTupleMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class MinerWrapper(BaseTupleMiner):

    # return_type should be "pos" or "neg" or "all"
    def __init__(self, return_type, miner=None, **kwargs):
        super().__init__(**kwargs)
        assert return_type in ["pos", "neg", "all"]
        self.return_type = return_type
        self.miner = miner

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        if self.miner:
            indices_tuple = self.miner(embeddings, labels, ref_emb, ref_labels)
        else:
            indices_tuple = lmu.get_all_pairs_indices(labels, ref_labels)

        a1, p, a2, n = indices_tuple

        if self.return_type == "pos":
            a2 = torch.LongTensor([]).to(labels.device)
            n = a2.clone()
        if self.return_type == "neg":
            a1 = torch.LongTensor([]).to(labels.device)
            p = a1.clone()

        return a1, p, a2, n

Then there are two approaches:

  1. If you want to use a specific set of negative samples:
import torch
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses import NTXentLoss

# you can optionally pass in a miner
# like MinerWrapper(return_type="neg", miner=MultiSimilarityMiner(epsilon=0.1))
neg_miner = MinerWrapper(return_type="neg")
regular_miner = MinerWrapper(return_type="all")
loss = NTXentLoss(temperature=0.10)

#### in training loop ####
# mine negatives from embeddings and negative_samples
ns_indices_tuple = neg_miner(embeddings, labels, negative_samples, negative_sample_labels)
# a shift is required because we're going to stack the embeddings
ns_indices_tuple = c_f.shift_indices_tuple(ns_indices_tuple, len(embeddings))

# mine pairs from embeddings
regular_indices_tuple = regular_miner(embeddings, labels)

# combine regular_indices_tuple and ns_indices_tuple
indices_tuple = []
for i in range(4):
    x = torch.cat([regular_indices_tuple[i], ns_indices_tuple[i]], dim=0)
    indices_tuple.append(x)
indices_tuple = tuple(indices_tuple)

all_embeddings = torch.cat([embeddings, negative_samples.detach()], dim=0)
all_labels = torch.cat([labels, negative_sample_labels], dim=0)

loss_value = loss(all_embeddings, all_labels, indices_tuple)
loss_value.backward()
  2. Alternatively, you can use MinerWrapper with CrossBatchMemory, so that only negative pairs will be mined from the queue. The queue consists of previous batches of embeddings, and is updated every time you compute a new loss value:

from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory

batch_miner = MinerWrapper(return_type="all")
queue_miner = MinerWrapper(return_type="neg")
inner_loss = NTXentLoss(temperature=0.10)
loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=1024, miner=queue_miner)

#### in training loop ####
indices_tuple = batch_miner(embeddings, labels)

# positive pairs are formed using batch_miner
# negative pairs come from both batch_miner and queue_miner
loss_value = loss(embeddings, labels, indices_tuple)
loss_value.backward()
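The index shift in the first approach is easy to get wrong, so here is a small plain-PyTorch illustration of why it is needed. It does not call the library; the pair indices are made-up values, and the sketch only shows the idea of offsetting indices that point into negative_samples once the two tensors are stacked.

```python
import torch

embeddings = torch.randn(4, 8)          # current batch
negative_samples = torch.randn(3, 8)    # extra negatives

# Negative pairs as (anchor row in embeddings, negative row in negative_samples).
anchors = torch.tensor([0, 0, 1, 3])
negatives = torch.tensor([0, 2, 1, 2])

# After stacking, row j of negative_samples becomes row len(embeddings) + j.
all_embeddings = torch.cat([embeddings, negative_samples.detach()], dim=0)
shifted_negatives = negatives + len(embeddings)

# The shifted indices select the same vectors as the unshifted ones did.
assert torch.equal(all_embeddings[shifted_negatives], negative_samples[negatives])
# Anchor indices point into the first block, so they need no shift.
assert torch.equal(all_embeddings[anchors], embeddings[anchors])
```

Without the offset, the negative indices would point at rows of the current batch instead of rows of the negative pool.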

CSerxy commented 4 years ago

Hi Kevin, thank you so much for such a detailed answer! Yes, you are right. What I want to do is keep the current framework of Hinton's paper (i.e., one positive augmentation pair, with all other augmentations in the batch regarded as negatives) while providing the mini-batch with more negatives.

However, I have three questions I want to confirm.

  1. The embeddings and labels used in your code are the corresponding representations and labels of the mini-batch.

Assuming we have a mini-batch with batch size 16 and embedding size 512, then the embeddings is a tensor with size 32*512 and labels = [16, 17, ..., 31, 0, 1, ..., 15].

If I have a negative sample pool with size 224 * 512, is it true that the negative_sample_labels used in your code should be something like negative_sample_labels = [0, 1, 2, 3, ..., 224]? (considering you have used c_f.shift_indices_tuple to stack the embeddings)

  2. To my understanding, the two approaches you provided don't conflict with each other, right?

In my case, I just want to provide the batches with more negatives and not compute gradients for those negatives. All embeddings from outside the current batch are regarded as negatives. It seems both approaches you provided fit my situation. Is that correct?

  3. In CrossBatchMemory, how are the previous batches' labels assigned? Are they automatically treated as negatives and updated?

Many thanks and looking forward to your reply!

KevinMusgrave commented 4 years ago

See my responses below.

Assuming we have a mini-batch with batch size 16 and embedding size 512, then the embeddings is a tensor with size 32*512 and labels = [16, 17, ..., 31, 0, 1, ..., 15].

It looks like you're setting labels to range(32). That means there will be no positive pairs in the batch. The labels tensor needs to be constructed such that positive pairs have the same label. That's the only requirement. Maybe you meant to do labels = [0, 1, ..., 15, 0, 1, ..., 15].

If I have a negative sample pool with size 224 * 512, is it true that the negative_sample_labels used in your code should be something like negative_sample_labels = [0, 1, 2, 3, ..., 224]? (considering you have used c_f.shift_indices_tuple to stack the embeddings)

You can set negative_sample_labels to something like range(16, 224+16). That way it will have no overlap with labels, so each element of negative_samples will be considered a negative with respect to embeddings.
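A short check of this labelling scheme, using the sizes from the question (16 pairs in the batch, 224 extra negatives). The variable names follow the thread; the overlap and count checks are added here for illustration only.

```python
import torch

batch_size = 16
num_negatives = 224

# Positive pairs share a label: row i and row i + batch_size are two views of the same item.
indices = torch.arange(batch_size)
labels = torch.cat((indices, indices))

# Labels for the negative pool start where the batch labels end.
negative_sample_labels = torch.arange(batch_size, batch_size + num_negatives)

# No label is shared between the batch and the pool,
# so every pool element is a negative for every batch element.
overlap = set(labels.tolist()) & set(negative_sample_labels.tolist())

# Each batch label occurs exactly twice, giving one positive partner per row.
counts = torch.bincount(labels)

print(len(overlap))
print(counts.tolist())
```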

  2. To my understanding, the two approaches you provided don't conflict with each other, right?

Yes, both approaches should work.

  3. In CrossBatchMemory, how are the previous batches' labels assigned? Are they automatically treated as negatives and updated?

CrossBatchMemory will store the labels that are given to it. In order to ensure that all elements in the queue are considered negatives, you need to construct your labels at each iteration so there is no overlap with the labels in the queue.

I think this code snippet should do everything you want (you can ignore my previous comment about creating MinerWrapper etc.):

import torch
from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory

memory_size = 1024
inner_loss = NTXentLoss(temperature=0.10) 
loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=memory_size)

#### training loop ####
for i, data in enumerate(dataloader):
    embeddings = model(data)
    num_pos_pairs = embeddings.size(0) // 2

    # create labels that indicate what the positive pairs are
    labels = torch.arange(0, num_pos_pairs)
    labels = torch.cat((labels, labels))

    # add an offset so that the labels do not overlap with any labels in the memory queue
    labels += i*num_pos_pairs

    # compute loss
    # positive pairs are from embeddings 
    # negative pairs are from embeddings + queue
    loss_value = loss(embeddings, labels)
    loss_value.backward()

If you construct your labels this way, then all elements in the CrossBatchMemory queue will always be considered negatives for the current batch.
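The offset scheme can be checked in isolation. The helper below, make_labels, is a hypothetical name introduced only for this sketch; it reproduces the label construction from the loop and confirms that labels from different iterations never collide.

```python
import torch

def make_labels(iteration, num_pos_pairs):
    """Labels for one batch: each pair shares a label, offset by the iteration index."""
    labels = torch.arange(num_pos_pairs)
    labels = torch.cat((labels, labels))
    return labels + iteration * num_pos_pairs

num_pos_pairs = 4
seen = set()
for i in range(3):
    batch_labels = make_labels(i, num_pos_pairs)
    current = set(batch_labels.tolist())
    # Labels of this batch never collide with labels from earlier iterations.
    assert not (current & seen)
    seen |= current

print(make_labels(2, num_pos_pairs).tolist())
```

Two things to watch, as observations rather than anything stated in the thread: the sketch assumes every batch has the same number of pairs, and enumerate restarts at 0 when the loop is re-entered for a new epoch, so a step counter that keeps increasing across epochs may be a safer offset.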

Edit: Actually this isn't exactly like MoCo, because the queue will contain old queries + keys, rather than just keys. I could add an optional input argument to CrossBatchMemory, which would be something like embeddings_for_queue, so that you can specify which embeddings you want to go into the queue. I should also add options for selecting only negatives from the queue, in case there are positives present.
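For comparison, a queue that stores only keys could look roughly like the following. This is a plain-PyTorch sketch under assumptions: KeyQueue is a hypothetical helper written for this illustration, not part of pytorch-metric-learning and not the proposed embeddings_for_queue option.

```python
import torch

class KeyQueue:
    """Fixed-size FIFO buffer of detached key embeddings (hypothetical helper)."""

    def __init__(self, embedding_size, memory_size):
        self.buffer = torch.zeros(memory_size, embedding_size)
        self.memory_size = memory_size
        self.ptr = 0      # next slot to overwrite
        self.count = 0    # number of valid rows

    @torch.no_grad()
    def enqueue(self, keys):
        # Store copies without gradient history, overwriting the oldest entries first.
        for key in keys.detach():
            self.buffer[self.ptr] = key
            self.ptr = (self.ptr + 1) % self.memory_size
            self.count = min(self.count + 1, self.memory_size)

    def negatives(self):
        # Only the rows that have actually been filled.
        return self.buffer[: self.count]


queue = KeyQueue(embedding_size=8, memory_size=6)
for step in range(3):
    queries = torch.randn(4, 8, requires_grad=True)
    keys = torch.randn(4, 8, requires_grad=True)
    # Only the keys are stored; the queries never enter the queue.
    queue.enqueue(keys)

print(queue.negatives().shape)
```

The stored rows carry no gradient history, so using them as negatives adds similarity computations but no extra backward pass through the encoder.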

CSerxy commented 4 years ago

Hi Kevin,

Thank you so much!! Your code is exactly what I need. It could work!

I also just did a sanity check comparing the loss before and after adding CrossBatchMemory. I set loss = CrossBatchMemory(loss=inner_loss, embedding_size=512, memory_size=batch_size) so that this loss should be the same as the NTXentLoss we used, right? Ideally, the results should be identical.

However, I found that the two losses differ: NTXentLoss = 0.17 and the CrossBatchMemory loss = 0.11. Is this normal?

KevinMusgrave commented 4 years ago

That's a good sanity check! I think I know the problem. CrossBatchMemory is including self-comparisons as positive pairs (i.e. an embedding compared to itself is counting as a positive pair). I'll try to fix this tomorrow and I'll add your sanity check to the unit tests.
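The effect of self-comparisons can be seen by counting label matches directly. This plain-PyTorch sketch is not the library's code or its fix; it only shows how comparing a batch against a memory that holds the same batch inflates the number of positive pairs.

```python
import torch

labels = torch.tensor([0, 1, 2, 0, 1, 2])   # three positive pairs

# When the memory holds exactly the current batch, labels are compared with themselves.
matches = labels.unsqueeze(1) == labels.unsqueeze(0)

# Every element matches itself, so the diagonal adds one spurious "positive" per row.
with_self = matches.sum().item()

# Masking the diagonal leaves only genuine positive pairs (counted in both directions).
not_self = ~torch.eye(len(labels), dtype=torch.bool)
without_self = (matches & not_self).sum().item()

print(with_self, without_self)
```

A self-pair has the maximum possible similarity, so including it would plausibly pull the averaged loss down, which is in line with the lower value reported in the sanity check.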

CSerxy commented 4 years ago

Excellent! Thanks!

KevinMusgrave commented 4 years ago

It should be fixed now. Download v0.9.89.dev2:

pip install pytorch-metric-learning==0.9.89.dev2

Let me know if it works for you.

CSerxy commented 4 years ago

Hi Kevin,

I really appreciate your hard work! Yes, the current version fixes the initial loss mismatch. However, when I ran the two versions, I found that the loss with the CrossBatchMemory wrapper gets stuck at some point and stops updating.

I ran the same sanity check experiment as in my last comment. Both losses start at 0.17 and get stuck at 0.155 for some time. But NTXentLoss starts decreasing again at epoch 3 (it stalls for roughly 2 epochs, which is normal), whereas the CrossBatchMemory loss stays stuck. After many epochs, NTXentLoss decreases to 0.001.

Do you know what could cause this? Thanks a lot!

KevinMusgrave commented 4 years ago

Let (a1,p) represent the positive pairs and (a2,n) represent the negative pairs.

KevinMusgrave commented 4 years ago

Closing for now, please re-open if you have more questions.

CSerxy commented 4 years ago

Thanks, Kevin! I think the problem I mentioned is caused by some other bugs, not related to your package. I really appreciate your help!