KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Multi-gpu training #10

Closed · JohnGiorgi closed this issue 4 years ago

JohnGiorgi commented 4 years ago

Hi, I am stuck on how multi-GPU training would work for loss functions with more than one negative, particularly NTXentLoss.

In SimCLR, the number of negatives for each positive pair is taken to be 2 * (N - 1) (every other augmented example produced from a minibatch of size N), and they find (as other works before them) that the larger the batch size, the more negatives there are and the better the learned representations.

DataParallel and DistributedDataParallel divide up a mini-batch, send each partition of examples to a GPU, compute gradients on each device, and then average these gradients before the optimizer step. But this means that each GPU computes its loss with N/n_gpus examples and, therefore, only 2 * (N/n_gpus - 1) negatives per positive pair.
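(To make that concrete with made-up numbers: with N = 256 and 8 GPUs, each replica would see only 32 examples, so each positive pair gets 2 * (32 - 1) = 62 negatives instead of 2 * (256 - 1) = 510.)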

My question is: how might I use a loss function from this library, such as NTXentLoss, with DataParallel or DistributedDataParallel that avoids this "issue"? I.e. that allows me to use multiple GPUs while maintaining 2 * (N - 1) negatives per positive pair.

KevinMusgrave commented 4 years ago

DataParallel should work fine, because it's still just 1 process running. This is what I always use.

I haven't used DistributedDataParallel. I think you're right that each process will only see part of the batch, and therefore the number of negative pairs will be wrong. At the moment I can't think of an easy way to fix that. Maybe the loss function in one process needs to get embeddings from all the other processes?

JohnGiorgi commented 4 years ago

Hmm, okay. I will have to look into DataParallel more carefully. The PyTorch documentation suggests it is splitting up a batch across devices and computing gradients independently on each device, so I would suspect that the problem I mentioned above still stands:

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension (other objects will be copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module

Is CrossBatchMemory a potential solution here?

KevinMusgrave commented 4 years ago

Yes, if you wrap your model with DataParallel, then when computing the output, the batch will be split across devices. But if you do not wrap the loss function with DataParallel, then the loss function will have access to the entire batch. So it will be able to compute all negative pairs.

In contrast, I think DistributedDataParallel will launch multiple separate processes. So the loss function in each process will not have access to the entire batch by default.
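
A minimal sketch of that setup (the linear layer, dummy data, and labels below are placeholders, not part of the library's examples):

import torch
import torch.nn as nn
from pytorch_metric_learning import losses

encoder = nn.Linear(512, 128)                 # stand-in for your actual model
model = nn.DataParallel(encoder)              # only the model is wrapped, so its forward pass is split across GPUs
loss_fn = losses.NTXentLoss(temperature=0.1)  # NOT wrapped, so it sees the full batch of embeddings

data = torch.randn(32, 512)                   # dummy batch: 16 positive pairs -> 32 samples
labels = torch.arange(16).repeat(2)
embeddings = model(data)                      # outputs are gathered back onto one device
loss = loss_fn(embeddings, labels)            # all negative pairs in the batch are available here
loss.backward()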

KevinMusgrave commented 4 years ago

CrossBatchMemory stores embeddings from previous iterations and creates pairs using the current iteration's embeddings combined with those previous embeddings. So it works with the assumption that embeddings drift slowly as training progresses. You can try it, but it will be very different from the SimCLR paper. Also, it's really slow right now and I need to fix that.

Edit: I suppose you could use it in your scenario by passing in chunks of the current batch. Still unnecessary in my opinion, as DataParallel works.
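
For reference, a rough sketch of how it could be wired up (the embedding size, memory size, and dummy batch here are arbitrary choices, not recommendations):

import torch
from pytorch_metric_learning import losses

inner_loss = losses.NTXentLoss(temperature=0.1)
loss_fn = losses.CrossBatchMemory(inner_loss, embedding_size=128, memory_size=1024)

embeddings = torch.randn(32, 128, requires_grad=True)  # current iteration's embeddings
labels = torch.arange(16).repeat(2)                    # 16 positive pairs
loss = loss_fn(embeddings, labels)  # pairs are formed from this batch plus the stored memory bank
loss.backward()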

KevinMusgrave commented 4 years ago

Here's something you can check.

from pytorch_metric_learning import losses, miners
# set margins to ensure that no pairs are left out for this example
mining_func = miners.PairMarginMiner(pos_margin=0, neg_margin=100, use_similarity=False)

# in your training loop:
pairs = mining_func(embeddings, labels)
print(mining_func.num_pos_pairs)
print(mining_func.num_neg_pairs)

This will print the number of positive and negative pairs "mined", but really it's just the total number of positive and negative pairs, because pos_margin and neg_margin are set to extreme values. As long as you don't wrap the mining function with DataParallel, it should print the correct number. You can still wrap your model with DataParallel, though.

JohnGiorgi commented 4 years ago

Trying to wrap my head around this example.

I generated a dummy batch of 16 positive pairs (32 embeddings total, each of dim 128). The printout gives me the expected number of positive pairs (32) but an unexpected number of negative pairs (960). I would have expected 2(batch_size - 1) = 2(32 - 1) = 62.

Any idea what I am misunderstanding?

import torch
from pytorch_metric_learning import losses, miners

batch_size = 16
embedding_dim = 128

# generate a dummy batch
anchor_embeddings = torch.randn(batch_size, embedding_dim)
positive_embeddings = torch.randn(batch_size, embedding_dim)
embeddings = torch.cat((anchor_embeddings, positive_embeddings))
indices = torch.arange(0, anchor_embeddings.size(0), device=anchor_embeddings.device)
labels = torch.cat((indices, indices))

# set margins to ensure that no pairs are left out for this example
mining_func = miners.PairMarginMiner(pos_margin=0, neg_margin=100, use_similarity=False)

# in your training loop:
pairs = mining_func(embeddings, labels)
print(mining_func.num_pos_pairs)  #=> 32
print(mining_func.num_neg_pairs)  #=> 960

KevinMusgrave commented 4 years ago

@JohnGiorgi There are 32 samples. For each sample, 1 of the other 31 samples is a positive and the remaining 30 are negatives. So 32 * 30 = 960 negative pairs.
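
A quick way to double-check that count directly from the labels (same dummy labels as in your snippet):

import torch

labels = torch.arange(16).repeat(2)                # [0..15, 0..15], as above
same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (32, 32) matrix of label matches
num_pos = (same.sum() - len(labels)).item()        # exclude self-pairs -> 32
num_neg = (~same).sum().item()                     # 32 * 30 = 960
print(num_pos, num_neg)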

JohnGiorgi commented 4 years ago

Right. Thank you!

JohnGiorgi commented 4 years ago

Okay, I think I have figured out how this works with DistributedDataParallel. Writing it up here in case anyone else stumbles on this problem, as it is non-trivial to solve:

You need to:

  1. all_gather embeddings from all replicas.
  2. Because gathered tensors carry no gradients, overwrite the gathered entry for the current replica with the embeddings tensor produced locally on that replica, which does have gradients back to the encoder.
  3. Concatenate the list of embeddings before computing the loss.

Some almost functional pseudo-code:

import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradients, so overwrite the gathered tensor for the
# current replica with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# Generating the labels is not shown here; it will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)

In my application, this appears to work. The trick you mentioned above (printing num_neg_pairs from the miner) gives the expected value.

cbaziotis commented 3 years ago

@JohnGiorgi Thanks for sharing your solution. I was wondering how your approach differs from other approaches that subclass torch.autograd.Function, such as:

  1. https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py that uses this Function.
  2. https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py#L224 that uses this Function.

Is there a practical difference or is it just a matter of style?

Thanks!

Update: This blog post claims that, with your approach, we should re-scale the gradients by the world size to obtain the correct gradients.
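
For context, here is a rough sketch of what those Functions do (illustrative only; the class name and details below are not the exact code from either repo): the forward pass all-gathers the embeddings, and the backward pass sums the gradients from every replica and returns the slice belonging to the local one.

import torch
import torch.distributed as dist

class AllGatherWithGrad(torch.autograd.Function):
    # Illustrative gather-with-gradients Function; requires an initialized process group.

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradients every replica computed for the gathered batch,
        # then return only the slice for this replica's local embeddings.
        grad_input = grad_output.clone()
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)
        start = dist.get_rank() * ctx.batch_size
        return grad_input[start:start + ctx.batch_size]

# usage (replaces the manual all_gather + overwrite above):
# embeddings = AllGatherWithGrad.apply(embeddings)
# loss = some_contrastive_loss(embeddings, labels)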

weiyx16 commented 3 years ago

(Quoting @cbaziotis's comment above in full.)

I ran an experiment on a small toy example and compared the gradients between the two methods, and I think the two approaches produce the same gradients.

chenxi52 commented 2 years ago

Do you think this version (https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py) should re-scale the gradients? And do its parameters also work on a single GPU? (I think so.)