facebookresearch / swav

PyTorch implementation of SwAV https://arxiv.org/abs/2006.09882

Surprising but interesting duplicated clusters #71

Closed Jiawei-Yang closed 3 years ago

Jiawei-Yang commented 3 years ago

Hi, thanks for your brilliant work!

I have found an interesting fact: many of the learned prototypes are duplicated.

I began with the question "how well do all the learned prototypes scatter?". So I downloaded the best pre-trained model from this link.

Interestingly, when I computed the pairwise cosine similarity among all prototypes, multiple pairs of prototypes had a similarity score of exactly 1, i.e. they turn out to be the same vector.

Here is the code.

import torch
import torch.nn.functional as F

model = torch.load('swav_800ep_pretrain.pth.tar', map_location='cpu')
protos = model['module.prototypes.weight']  # (3000, 128); rows are unit-norm (SwAV normalizes the prototype weights during training), so dot products are cosine similarities
similarity = protos @ protos.T  # (3000, 3000) pairwise cosine similarities
non_diag = similarity - torch.eye(3000)  # zero out the diagonal self-similarities, keeping only cross-prototype similarities

# Take the first prototype as an example
values, indices = non_diag[0].sort()

>>> print(values[-10:])
tensor([0.5267, 0.5267, 0.5267, 0.6343, 0.6343, 0.6343, 0.6344, 1.0000, 1.0000,
        1.0000])
>>> print(indices[10:])
tensor([  91,  932, 1244,  ...,  937, 1819, 2363])
>>> print(protos[0])
tensor([-0.0873, -0.0289,  0.1113,  0.1079,  0.0845, -0.0683,  0.0359, -0.0891,
         0.1160,  0.0086,  0.0602, -0.0444, -0.0620, -0.0612, -0.1079, -0.0714,
        -0.1299,  0.0790, -0.0428,  0.0628,  0.0202,  0.0361,  0.0414,  0.1667,
        -0.1552, -0.0179,  0.1873,  0.1460,  0.1022,  0.0320,  0.0937, -0.0669,
        -0.0611, -0.0737,  0.0175, -0.1226, -0.1596, -0.0358, -0.1683,  0.0984,
        -0.0613, -0.1481, -0.0617,  0.0235, -0.0650,  0.0165,  0.0387, -0.0046,
         0.0452,  0.1192,  0.0245, -0.0401,  0.0379, -0.1750,  0.0459, -0.0631,
        -0.0551, -0.1220, -0.0352,  0.0178,  0.1306, -0.1511,  0.0077,  0.0274,
         0.0032, -0.0396, -0.1273,  0.0903, -0.1012,  0.0024,  0.1248,  0.1845,
         0.0089, -0.0679, -0.0545,  0.0511,  0.0709,  0.1274, -0.0679, -0.0050,
         0.0006, -0.1155,  0.1040,  0.0527, -0.1587,  0.1085,  0.0560,  0.0032,
        -0.0046, -0.0338, -0.0009, -0.0129, -0.0033,  0.0171,  0.0436, -0.0369,
         0.0274,  0.0577, -0.0145, -0.0775, -0.0514,  0.0060, -0.0040,  0.0495,
        -0.0725,  0.0900, -0.0259,  0.1121, -0.0870, -0.0796,  0.2087, -0.0425,
        -0.0596, -0.0466,  0.0146,  0.0791, -0.0211,  0.1539, -0.1551, -0.0358,
        -0.1216, -0.1700,  0.0100,  0.1275,  0.0419,  0.2577, -0.0983,  0.0249])
>>> print(protos[2363])
tensor([-0.0873, -0.0289,  0.1113,  0.1079,  0.0845, -0.0683,  0.0359, -0.0891,
         0.1160,  0.0086,  0.0602, -0.0444, -0.0620, -0.0612, -0.1079, -0.0714,
        -0.1299,  0.0790, -0.0428,  0.0628,  0.0202,  0.0361,  0.0414,  0.1667,
        -0.1552, -0.0179,  0.1873,  0.1460,  0.1022,  0.0320,  0.0937, -0.0669,
        -0.0611, -0.0737,  0.0175, -0.1226, -0.1596, -0.0358, -0.1683,  0.0984,
        -0.0613, -0.1481, -0.0617,  0.0235, -0.0650,  0.0165,  0.0387, -0.0046,
         0.0452,  0.1192,  0.0245, -0.0401,  0.0379, -0.1750,  0.0459, -0.0631,
        -0.0551, -0.1220, -0.0352,  0.0178,  0.1306, -0.1511,  0.0077,  0.0274,
         0.0032, -0.0396, -0.1273,  0.0902, -0.1012,  0.0024,  0.1248,  0.1845,
         0.0089, -0.0679, -0.0545,  0.0511,  0.0709,  0.1274, -0.0679, -0.0050,
         0.0006, -0.1155,  0.1040,  0.0527, -0.1587,  0.1085,  0.0560,  0.0032,
        -0.0046, -0.0338, -0.0009, -0.0129, -0.0033,  0.0171,  0.0436, -0.0369,
         0.0274,  0.0577, -0.0145, -0.0775, -0.0514,  0.0060, -0.0040,  0.0495,
        -0.0725,  0.0900, -0.0259,  0.1121, -0.0870, -0.0796,  0.2087, -0.0425,
        -0.0596, -0.0466,  0.0146,  0.0791, -0.0211,  0.1539, -0.1551, -0.0358,
        -0.1216, -0.1700,  0.0100,  0.1275,  0.0419,  0.2577, -0.0983,  0.0249])

So prototypes #0, #937, #1819, and #2363 are in fact identical.
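
To check how widespread this is beyond prototype #0, here is a small follow-up sketch (not part of the original post) that thresholds the full similarity matrix and groups near-identical prototypes with a simple union-find. The 0.999 cutoff is an arbitrary assumption, chosen because the largest non-duplicate similarity above is only about 0.63.

import torch

model = torch.load('swav_800ep_pretrain.pth.tar', map_location='cpu')
protos = model['module.prototypes.weight']   # (3000, 128)
similarity = protos @ protos.T               # pairwise cosine similarities

# Simple union-find over prototype indices
parent = list(range(len(protos)))

def find(i):
    while parent[i] != i:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i

# Union every pair whose cosine similarity is numerically 1 (0.999 threshold is an assumption)
rows, cols = torch.where(torch.triu(similarity, diagonal=1) > 0.999)
for i, j in zip(rows.tolist(), cols.tolist()):
    parent[find(i)] = find(j)

# Collect groups of duplicated prototypes
groups = {}
for i in range(len(protos)):
    groups.setdefault(find(i), []).append(i)
duplicated = [g for g in groups.values() if len(g) > 1]

print(len(duplicated), "groups of duplicated prototypes")
print(sum(len(g) for g in duplicated), "prototypes involved in a duplicate group")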

Had you noticed this, and do you have any idea why it happens?

Best, Jiawei

mathildecaron31 commented 3 years ago

Hi @Jiawei-Yang, thanks for your kind words and for sharing this finding! I had not noticed that; I think it might explain why using more clusters does not have a big impact on performance. Feel free to post more if you have other findings/analyses :).