KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

VICReg requires torch >= 1.10 #467

Closed. PDomonkos closed this issue 2 years ago.

PDomonkos commented 2 years ago

Hi,

`torch.cov`, used in the `covariance_loss` function, requires torch >= 1.10. Replacing:

```python
_, D = emb.size()
cov_emb = torch.cov(emb.T)
cov_ref_emb = torch.cov(ref_emb.T)
```

with:

```python
batch_size, D = emb.size()
cov_emb = (emb.T @ emb) / (batch_size - 1)
cov_ref_emb = (ref_emb.T @ ref_emb) / (batch_size - 1)
```

would make it compatible with older versions as well.
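One detail worth keeping in mind with this replacement: `torch.cov` subtracts each feature's mean before forming the product, so `(emb.T @ emb) / (batch_size - 1)` matches it only when `emb` and `ref_emb` are already mean-centered. The sketch below is not the library's code. It centers explicitly, and, as an assumption, follows the reference VICReg code in penalizing the squared off-diagonal covariance entries divided by `D`. The helper names `manual_cov` and `off_diagonal_sq_sum` are hypothetical.

```python
import torch


def manual_cov(x: torch.Tensor) -> torch.Tensor:
    """Sample covariance of the columns of x (shape N x D) without torch.cov.

    torch.cov subtracts the per-feature mean, so x is centered first;
    skipping this step only gives the same result if x already has
    zero column means.
    """
    n = x.size(0)
    x = x - x.mean(dim=0)  # center each feature (column)
    return (x.T @ x) / (n - 1)  # unbiased estimate, like torch.cov's default correction=1


def off_diagonal_sq_sum(cov: torch.Tensor) -> torch.Tensor:
    # Zero out the diagonal, then sum the squares of the remaining entries.
    off = cov - torch.diag(torch.diagonal(cov))
    return off.pow(2).sum()


def covariance_loss(emb: torch.Tensor, ref_emb: torch.Tensor) -> torch.Tensor:
    # Assumed normalization: divide by the embedding dimension D and
    # sum the penalty over both branches.
    _, d = emb.size()
    return (off_diagonal_sq_sum(manual_cov(emb)) + off_diagonal_sq_sum(manual_cov(ref_emb))) / d
```

On torch >= 1.10, `manual_cov(emb)` should agree with `torch.cov(emb.T)` up to floating-point error, which makes it an easy check that the fallback is a drop-in replacement.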

Another remark: in the reference implementation, the variance loss is divided by 2: https://github.com/facebookresearch/vicreg/blob/5e7b38f4586384bbb0d9a035352fab1d8f03b3b4/main_vicreg.py#L207. Was this intentionally left out, or am I missing something?
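To make the difference concrete, here is a minimal sketch of the variance (hinge-on-std) term with and without the halving. It is not the library's code. The `eps=1e-4` default mirrors the constant used in the reference `main_vicreg.py` as far as I recall, and the `halve` flag is a hypothetical parameter added only for comparison.

```python
import torch
import torch.nn.functional as F


def variance_loss(emb: torch.Tensor, ref_emb: torch.Tensor, eps: float = 1e-4, halve: bool = True) -> torch.Tensor:
    # Per-feature standard deviation; eps keeps sqrt well-behaved near zero variance.
    std_emb = torch.sqrt(emb.var(dim=0) + eps)
    std_ref_emb = torch.sqrt(ref_emb.var(dim=0) + eps)
    # Hinge at 1: only features whose std falls below 1 are penalized.
    loss = torch.mean(F.relu(1 - std_emb)) + torch.mean(F.relu(1 - std_ref_emb))
    # halve=True divides each branch's term by 2 (averages the two branches);
    # halve=False sums them.
    return loss / 2 if halve else loss
```

Because the halving is a constant factor, dropping it is equivalent to doubling the weight on the variance term relative to the other two terms, so it changes the balance of the loss rather than its minimizer for the variance term alone.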

KevinMusgrave commented 2 years ago

Thanks for the tip regarding pytorch version.

Re: division by 2, I don't think it was left out intentionally. I'll take a closer look later.

cwkeam commented 2 years ago

@KevinMusgrave Hi! I'll take a look at this, since I contributed this loss function. It's a bit late in my time zone right now, so I'll get to it tomorrow.

KevinMusgrave commented 2 years ago

Thanks @cwkeam!

KevinMusgrave commented 2 years ago

Fixed in v1.3.1