facebookresearch / barlowtwins

PyTorch implementation of Barlow Twins.
MIT License

Why do we average out correlation matrices from different GPUs? Is this mathematically valid? #33

Closed radekd91 closed 3 years ago

radekd91 commented 3 years ago

Thanks for this great work!

I am a bit confused about the computation of the Barlow Twins loss in the multi-GPU setting. If I understand it correctly, each batch is split into smaller minibatches and these are then processed on separate GPUs. Each GPU computes the cross correlation matrix corresponding to its minibatch. The cross correlations between samples on different GPUs are not computed. It is not clear to me why the different cross correlation matrices are averaged across GPUs. This produces a mean correlation matrix, which is then used for the loss computation.

Why not compute the loss for each correlation matrix separately and only average the final losses? Or even better, why not compute the full cross correlation matrix (i.e. gather all embedding vectors onto one device and compute the cross correlation there)?

I fail to see why summing correlation matrices is a valid mathematical operation - or is it just an implementation "hack" that makes things easier? I guess that since all cross correlation matrices ideally converge towards identity matrices (as forced by the loss function), averaging them does not strictly break the convergence - is that the case?

I am not very experienced with distributed deep learning so there may be technical things I don't understand. Thanks for your help.

https://github.com/facebookresearch/barlowtwins/blob/a655214c76c97d0150277b85d16e69328ea52fd9/main.py#L206-L223

jzbontar commented 3 years ago

> Or even better, why not compute the full cross correlation matrix (i.e. gather all embedding vectors onto one device and compute the cross correlation there)?

Summing cross correlation matrices (like we do in our code) is equivalent to computing the full cross correlation matrix by gathering all embedding vectors onto one device. They give you exactly the same result.

Think about how to distribute a dot product operation across n machines (computing the cross correlation matrix is basically just a bunch of dot products, one for each pair of features). You could split the vectors into n chunks, compute n smaller dot products (one dot product for each of the n chunks) and sum them to get the final result. Or if you prefer code:

>>> import torch
>>> x = torch.Tensor(8).normal_()
>>> y = torch.Tensor(8).normal_()
>>> torch.allclose(x[:4] @ y[:4] + x[4:] @ y[4:], x @ y)
True
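
The same identity can be checked at the matrix level, which is closer to what the training code does: each GPU computes the cross correlation matrix of its chunk of the batch and the chunk-wise matrices are summed. A small sketch with arbitrary sizes (not code from this repository; the atol just absorbs the different floating-point summation order):

>>> import torch
>>> z1 = torch.randn(32, 16)   # full batch of embeddings, view 1
>>> z2 = torch.randn(32, 16)   # full batch of embeddings, view 2
>>> full = z1.T @ z2           # cross correlation computed on one device
>>> chunked = sum(a.T @ b for a, b in zip(z1.chunk(4), z2.chunk(4)))  # 4 "GPUs"
>>> torch.allclose(full, chunked, atol=1e-5)
True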
radekd91 commented 3 years ago

Thanks for the explanation. It's clear to me, now.

ltnghia commented 2 years ago

Hi, how do we do this on a single GPU? torch.distributed does not seem to work on a single GPU. I use torch.nn.DataParallel to run the code on a single GPU.

# empirical cross-correlation matrix
c = self.bn(z1).T @ self.bn(z2)

# sum the cross-correlation matrix between all gpus
c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)

on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
off_diag = off_diagonal(c).pow_(2).sum()
loss = on_diag + self.args.lambd * off_diag
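
One possible single-GPU adaptation (a sketch only, assuming z1 and z2 are the projector outputs for the full batch and torch.distributed has not been initialized) is to keep the normalization by the batch size and drop the all_reduce, since there is nothing to reduce across:

# empirical cross-correlation matrix for the whole batch (single process)
c = self.bn(z1).T @ self.bn(z2)
c.div_(self.args.batch_size)  # still needed so the diagonal is on the order of 1

# no all_reduce here: the full batch already lives on this device
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
off_diag = off_diagonal(c).pow_(2).sum()
loss = on_diag + self.args.lambd * off_diag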
TQi-Yang commented 2 years ago

Hi @ltnghia. Is the problem solved? I think the following code may not be needed, right?

# sum the cross-correlation matrix between all gpus
c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)
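
As a side note, torch.distributed can also be initialized with a single process, in which case all_reduce is effectively a no-op and the code runs unchanged on one GPU; the division by args.batch_size still appears necessary so that the diagonal of c stays on the order of 1 rather than the batch size. A standalone sketch of the no-op behaviour (the backend, address and port are arbitrary choices, not values from the repository):

import torch
import torch.distributed as dist

# single-process "group": rank 0 of world size 1
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

c = torch.randn(8, 8)
before = c.clone()
dist.all_reduce(c)                 # sums over all ranks; with one rank this changes nothing
print(torch.allclose(before, c))   # True
dist.destroy_process_group()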