sthalles / SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://sthalles.github.io/simple-self-supervised-learning/

Info NCE loss #33

Open fjbriones opened 3 years ago

fjbriones commented 3 years ago

Hi, may I ask how you calculated the InfoNCE loss in this work? I am confused about the methodology, as it is quite different from the authors' code.

You return labels of all 0, which suggests you only want to account for the negative labels. However, in this code you use the logits for both the negative samples and the positive sample (I'm assuming the latter is the augmented counterpart of the image). May I ask the reasoning behind this implementation?

https://github.com/sthalles/SimCLR/blob/1848fc934ad844ae630e6c452300433fe99acfd9/simclr.py#L51-L55

P.S.: I am still at a loss as to how you were able to simplify the code to calculating only the negative samples. Hopefully this can be clarified in your reply. Thank you!

ha-lins commented 3 years ago

Same question! I think the code should be revised to calculate both negative & positive samples.

fjbriones commented 3 years ago

I think I sort of get it now. The array of zero labels is meant to indicate the positive label for each pass. Since at line 51 the positive features are placed at the front of the logits tensor, index 0 is the positive class for every member of the batch.
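For anyone skimming, here is a condensed, runnable sketch of what the linked lines compute (paraphrasing simclr.py; it assumes `features` stacks the n_views augmented views of each image along the batch dimension):

```python
import torch
import torch.nn.functional as F

def info_nce_logits(features, batch_size, n_views=2, temperature=0.07):
    # labels[i] == labels[j] iff rows i and j are views of the same image
    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # drop self-similarities on the diagonal
    mask = torch.eye(labels.shape[0], dtype=torch.bool)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # positives first, negatives after -> the correct class index is 0 for every row
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long)
    return logits, labels
```

Because `positives` is concatenated before `negatives`, column 0 of every row is the positive logit, which is exactly what the all-zero labels point at.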

fjbriones commented 3 years ago
```python
# Patch for the tail of info_nce_loss: swap each row's positive (column 0)
# into a random column so that the target class is not always 0.
random_labels = torch.randint(low=0, high=logits.shape[1], size=(logits.shape[0], 1)).to(device)
index = torch.arange(logits.shape[0]).to(device).unsqueeze(1)
labels_access = torch.cat([index, random_labels], 1)
labels_access = torch.transpose(labels_access, 0, 1)

# Read the value at each row's random column, move it to column 0,
# then write the positive into the random column (a per-row swap).
temp = logits[tuple(labels_access)]
logits[:, 0] = temp
logits[tuple(labels_access)] = positives.squeeze()

logits = logits / temperature
return logits, random_labels.squeeze()
```

This is a possible solution that randomly places the positives, so that the target output for the network is not always 0. That said, the all-zeros scheme seems to work in their case and in some other experiments I made, so I guess the implemented one is fine.
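Either way, the returned pair is consumed the same way downstream; with this patch the target just varies per row instead of always being 0 (a minimal sketch, assuming `info_nce_loss` returns the pair as above):

```python
criterion = torch.nn.CrossEntropyLoss()
logits, labels = info_nce_loss(features)  # labels: all zeros, or random indices with this patch
loss = criterion(logits, labels)
```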

zhao1iang commented 2 years ago

https://github.com/sthalles/SimCLR/issues/16 has a nice explanation. In a nutshell, the first column holds the logits of the positive samples, so the labels are all set to zero.

jmarrietar commented 2 years ago

Basically, the implementation puts the positive instances in the first column of the logits (that's why the label is 0 in all cases).

RongfanLi98 commented 2 years ago

Denote N = batch_size * n_views, for example 64. The logits are a [64, 63] tensor and the labels a [64] tensor; note the dimensions.

Each row of the logits holds one sample's similarity to the other 63 samples, with the positive one in the first column, so the correct class is 0. The labels are all 0 because every one of the 64 samples has its positive pair in the first column. You could equally place the positives in the last column (index 62), and the right labels would then all be 62.
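A toy shape check of that setup (random logits, not the repo code, just to illustrate the dimensions):

```python
import torch

batch_size, n_views = 32, 2
N = batch_size * n_views                    # 64 rows
logits = torch.randn(N, N - 1)              # 1 positive + 62 negatives per row -> [64, 63]
labels = torch.zeros(N, dtype=torch.long)   # positive sits in column 0
loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, labels.shape)           # torch.Size([64, 63]) torch.Size([64])
```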

Aaaattack commented 2 years ago

I know why the labels are always zero. But if n_views is not 2 but, say, 3 or 4, the positive samples are not always in column 0: they occupy columns (0, 1) when n_views is 3 and (0, 1, 2) when n_views is 4.

muyuuuu commented 2 years ago

torch.nn.CrossEntropyLoss() = LogSoftmax + NLLLoss; you should look at the details of NLLLoss.
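A quick check of that identity:

```python
import torch

logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])

ce = torch.nn.CrossEntropyLoss()(logits, targets)
nll = torch.nn.NLLLoss()(torch.nn.LogSoftmax(dim=1)(logits), targets)
assert torch.allclose(ce, nll)  # CrossEntropyLoss == LogSoftmax followed by NLLLoss
```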

here101 commented 2 years ago

> I know why the labels are always zero. But if n_views is not 2 but, say, 3 or 4, the positive samples are not always in column 0: they occupy columns (0, 1) when n_views is 3 and (0, 1, 2) when n_views is 4.

If I want to amend the code for 3 or 4 views, how can I do it? Please give me some tips if you are free. Thanks in advance!

Aaaattack commented 2 years ago

> > I know why the labels are always zero. But if n_views is not 2 but, say, 3 or 4, the positive samples are not always in column 0: they occupy columns (0, 1) when n_views is 3 and (0, 1, 2) when n_views is 4.
>
> If I want to amend the code for 3 or 4 views, how can I do it? Please give me some tips if you are free. Thanks in advance!

My solution is simple. I didn't modify info_nce_loss; instead I do several cross-entropy loss calculations. Since CrossEntropyLoss is based on NLLLoss, labels that all share the same value simply select one column of the logits.

For example, if n_views is 4, the first 3 columns of the logits are positives, so there are 3 cross-entropy calculations: labels of all 0 select the first column, labels of all 1 the second column, and labels of all 2 the third column.
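A minimal sketch of that idea (hypothetical helper, assuming the logits keep the repo's layout with the n_views - 1 positive columns first; averaging the terms is a choice, summing works too):

```python
import torch

def multi_view_ce(logits, n_views):
    # one cross-entropy term per positive column:
    # labels of all k select column k as the target class
    criterion = torch.nn.CrossEntropyLoss()
    losses = [
        criterion(logits, torch.full((logits.shape[0],), k,
                                     dtype=torch.long, device=logits.device))
        for k in range(n_views - 1)
    ]
    return torch.stack(losses).mean()
```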

rainbow-xiao commented 1 year ago

run.py contains `assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."`, so the released code only supports two views; if we want to try more, we have to make our own modifications.