sthalles / SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://sthalles.github.io/simple-self-supervised-learning/
MIT License

Issue with batch-size #40

Closed Mayurji closed 1 year ago

Mayurji commented 2 years ago

In the function info_nce_loss, line 28 creates the labels from batch_size. The STL10 dataset, however, has 100,000 images. That count is divisible by a batch_size of 32, but a batch_size of 64 or 128 leaves a remainder of 32.

Having batch_size != 32 therefore causes an error on line 42, because the similarity matrix is built from the features while the labels are built from the batch size.

For instance, if batch_size = 128, only 32 images remain in the last iteration of the data_loader. Since we create two variants of each image, we'll have 64 images. Line 28 still produces 128 x 2 = 256 labels, while the similarity matrix is (64 x 128) times (128 x 64) => (64 x 64). The mask built from the labels is (256 x 256), which causes a "dimension mismatch".
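
The arithmetic above can be checked without loading any data. This is a minimal sketch in plain Python; last_batch_shapes is a hypothetical helper written for illustration and is not part of the repository. It assumes the final, smaller batch is not dropped by the data loader.

```python
def last_batch_shapes(dataset_size, batch_size, n_views=2):
    """Return (label_rows, feature_rows) for the final batch of an epoch.

    Hypothetical helper: assumes the last partial batch is kept.
    """
    remainder = dataset_size % batch_size
    last_batch = remainder if remainder else batch_size
    # Labels are built from the configured batch size (the behaviour described for line 28).
    label_rows = batch_size * n_views
    # Features are built from the images actually loaded in the final batch.
    feature_rows = last_batch * n_views
    return label_rows, feature_rows


for bs in (32, 64, 128):
    label_rows, feature_rows = last_batch_shapes(100_000, bs)
    status = "ok" if label_rows == feature_rows else "mismatch"
    print(f"batch_size={bs}: labels={label_rows}, features={feature_rows} -> {status}")
```

With 100,000 images, only batch_size = 32 gives matching row counts; 64 and 128 both end on a batch of 32 images, i.e. 64 feature rows.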

Solution: Change Line 28 as below

labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)


pengzhangzhi commented 2 years ago

such a good solution!

XY-boy commented 1 year ago

You helped me a lot!

rainbow-xiao commented 1 year ago

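One correction: divide by self.args.n_views instead of a hard-coded 2, so the labels also match when n_views is not 2. That is, change

labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

to

labels = torch.cat([torch.arange(features.shape[0]//self.args.n_views) for i in range(self.args.n_views)], dim=0)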
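
As a rough check of this version, here is a small standalone sketch. make_labels is a hypothetical helper, not the repository's function, and the 128-dimensional features and the view-by-view concatenation along the batch axis are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def make_labels(features, n_views):
    """Build the pairwise positive mask from the rows actually present in `features`.

    Hypothetical helper: assumes features are stacked view by view along dim 0.
    """
    per_view = features.shape[0] // n_views
    labels = torch.cat([torch.arange(per_view) for _ in range(n_views)], dim=0)
    # Entry (i, j) is 1.0 when rows i and j are views of the same image.
    return (labels.unsqueeze(0) == labels.unsqueeze(1)).float()


# Last batch of the example: 32 images x 2 views, with an assumed feature size of 128.
features = F.normalize(torch.randn(64, 128), dim=1)
similarity_matrix = features @ features.T
labels = make_labels(features, n_views=2)
assert labels.shape == similarity_matrix.shape == (64, 64)

# With 4 views the same code still lines up, which a hard-coded // 2 would not.
features_4 = torch.randn(32 * 4, 128)
assert make_labels(features_4, n_views=4).shape == (128, 128)
```

Because the label count is derived from features.shape[0], the mask follows the size of whatever batch arrives, including a short final one.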