Open RylanSchaeffer opened 1 year ago
This looks interesting, will have a closer look later :)
Sounds good!
The loss is quite straightforward. Suppose x has shape (batch size, number of views/transformations, projection output dimension), where the number of views is e.g. 2, 4, or 32. MMCR proposes the following loss:
import torch
# x: (batch_size, num_views, dim); average over views -> (batch_size, dim)
avg_over_views = torch.mean(x, dim=1)
loss = -torch.linalg.matrix_norm(avg_over_views, ord="nuc")
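For anyone who wants to try this out, here is a minimal self-contained sketch of the two lines above wrapped in a function. The function name `mmcr_loss_sketch` and the `normalize` flag are hypothetical and not part of Lightly or the paper's code. The optional L2 normalization reflects my understanding that the paper projects embeddings onto the unit sphere before averaging; treat that as an assumption and check it against the paper.

```python
import torch
import torch.nn.functional as F


def mmcr_loss_sketch(x: torch.Tensor, normalize: bool = False) -> torch.Tensor:
    """Negative nuclear norm of the per-sample view centroids.

    x: tensor of shape (batch_size, num_views, dim).
    normalize: if True, L2-normalize each embedding first (an assumption
        about the paper's setup, not something stated in this thread).
    """
    if x.ndim != 3:
        raise ValueError(f"expected a 3D tensor, got shape {tuple(x.shape)}")
    if normalize:
        # Project every embedding onto the unit sphere.
        x = F.normalize(x, dim=-1)
    # Average over the views dimension -> (batch_size, dim).
    centroids = x.mean(dim=1)
    # Nuclear norm = sum of singular values of the centroid matrix.
    return -torch.linalg.matrix_norm(centroids, ord="nuc")


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 4, 16, requires_grad=True)  # 8 samples, 4 views, dim 16
    loss = mmcr_loss_sketch(x, normalize=True)
    loss.backward()  # gradients flow through the SVD-based norm
    print(loss.item(), x.grad.shape)
```

Minimizing this value maximizes the sum of singular values of the centroid matrix, which is what the snippet above computes.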
Arxiv link: https://arxiv.org/abs/2303.03307
I'm working on implementing this now.
Remaining tasks before this issue is complete:
This is a NeurIPS 2023 paper with a new SSL vision method: https://neurips.cc/virtual/2023/poster/70447
It'd be great if Lightly could add it.