mdiephuis / SimCLR

PyTorch implementation of "A Simple Framework for Contrastive Learning of Visual Representations"
MIT License

About gradient accumulation #13

Open wujunjie1998 opened 3 years ago

wujunjie1998 commented 3 years ago

Hi:

Thanks for your implementation. I have a question about the gradient accumulation part of the NT-Xent loss. Even though we divide the loss by num_accumulation_steps for each mini-batch, the expression loss = torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum))) still makes the losses incomparable. The reason is that the denominator "torch.sum(sim_mat, dim=-1) - norm_sum" is computed over a matrix of shape [2 * Batch_size, 2 * Batch_size], so each sample only sees the negatives in its own mini-batch. For example, the loss values for "Batch_size 256, accumulation step 1" and "Batch_size 64, accumulation step 4" are not similar.

Any comments about this?

Thanks
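The effect described in the question can be reproduced with a small sketch. The function below follows the shape of the quoted expression (`sim_mat`, `sim_match`, `norm_sum`), but how the repo actually builds these tensors is an assumption here; in particular, `norm_sum` is taken to be the diagonal self-similarity term, which equals `exp(1 / tau)` for unit-norm embeddings. With random embeddings, the full-batch loss and the accumulated loss should differ by roughly log(511 / 127) ≈ 1.39, because each denominator sums over 2N - 1 = 511 entries in one case and only 127 in the other.

```python
import torch
import torch.nn.functional as F


def nt_xent(z_i: torch.Tensor, z_j: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """NT-Xent written in the form quoted in the issue (hypothetical reconstruction)."""
    n = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2N, d], unit-norm rows
    sim_mat = torch.exp(z @ z.t() / tau)                  # [2N, 2N] exponentiated similarities
    # Diagonal entries are exp(1 / tau) up to rounding; remove them from each row sum.
    norm_sum = sim_mat.diagonal()
    # The positive for row k is row k + N, and vice versa.
    pos = torch.exp(torch.sum(z[:n] * z[n:], dim=-1) / tau)
    sim_match = torch.cat([pos, pos], dim=0)              # [2N]
    return torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum)))


torch.manual_seed(0)
z_i, z_j = torch.randn(256, 128), torch.randn(256, 128)

# "Batch_size 256, accumulation step 1": one loss over all 512 views.
full = nt_xent(z_i, z_j)

# "Batch_size 64, accumulation step 4": four losses, each divided by the step count.
steps = 4
accum = sum(nt_xent(a, b) / steps for a, b in zip(z_i.chunk(steps), z_j.chunk(steps)))

print(f"full batch: {full.item():.3f}  accumulated: {accum.item():.3f}")
```

Dividing by `steps` only averages the four micro-batch losses; it does not change what each denominator sums over, which is why the two numbers do not line up.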

AlekseySh commented 3 years ago

IMHO, this loss is not additive by design (like the F1 score).
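To make the non-additivity concrete, write the standard SimCLR form of NT-Xent for N positive pairs (2N views), with cosine similarity $s_{ik}$, temperature $\tau$, and $p(i)$ the index of the positive for view $i$. This is the textbook formulation; the repo's code is assumed, not confirmed, to match it:

$$
\ell_i = -\log \frac{\exp(s_{i,p(i)}/\tau)}{\sum_{k \neq i} \exp(s_{ik}/\tau)}, \qquad
\mathcal{L}_{\text{full}} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_i
$$

With gradient accumulation over $m$ micro-batches $B_1, \dots, B_m$ (each holding $N/m$ pairs, i.e. $2N/m$ views) and each micro-batch loss divided by $m$, the quantity actually optimized is

$$
\mathcal{L}_{\text{acc}} = \frac{1}{m} \sum_{b=1}^{m} \mathcal{L}^{(b)} = \frac{1}{2N} \sum_{b=1}^{m} \sum_{i \in B_b} \ell_i^{(b)}, \qquad
\ell_i^{(b)} = -\log \frac{\exp(s_{i,p(i)}/\tau)}{\sum_{k \in B_b,\, k \neq i} \exp(s_{ik}/\tau)}
$$

Each $\ell_i^{(b)}$ drops the $2N - 2N/m$ negatives that live in other micro-batches, and those terms are strictly positive, so $\ell_i^{(b)} < \ell_i$ for $m > 1$ and therefore $\mathcal{L}_{\text{acc}} < \mathcal{L}_{\text{full}}$. For the example in the question ($N = 256$, $m = 4$), each denominator shrinks from 511 terms to 127. Dividing by $m$ only rescales; it cannot put the missing negatives back, and the gradients differ for the same reason.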

AlekseySh commented 3 years ago

So the features should be collected over all the virtual batches first, and the loss should then be applied once to the collected features.
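One way to do this without holding the whole batch's activation graph in memory is a two-pass scheme often called gradient caching: embed every micro-batch with autograd off, apply NT-Xent once to the concatenated embeddings to get the gradient of the loss with respect to each embedding, then re-run each micro-batch with autograd on and backpropagate those cached gradients. The sketch below is an illustration, not this repo's code: `nt_xent`, `cached_contrastive_step`, `encoder`, and `micro_bs` are hypothetical names, and the loss is the standard cross-entropy form of NT-Xent. It assumes the encoder computes the same output in both passes (no dropout, and inputs already augmented).

```python
import torch
import torch.nn.functional as F


def nt_xent(z_i: torch.Tensor, z_j: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Standard NT-Xent loss, written with cross_entropy (not the repo's exact code)."""
    n = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    logits = z @ z.t() / tau
    # Exclude self-similarity from every softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive for view k is view k + n, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, targets)


def cached_contrastive_step(encoder, x_i, x_j, micro_bs, tau=0.5):
    """Accumulate full-batch NT-Xent gradients into the encoder's .grad, micro_bs samples at a time."""
    chunks_i, chunks_j = x_i.split(micro_bs), x_j.split(micro_bs)

    # Pass 1: collect the embeddings of every micro-batch without keeping any graph.
    with torch.no_grad():
        z_i = torch.cat([encoder(c) for c in chunks_i])
        z_j = torch.cat([encoder(c) for c in chunks_j])

    # Apply the loss once to the collected embeddings, treated as leaf tensors,
    # to get d(loss)/d(embedding) for every sample in the large batch.
    z_i.requires_grad_(True)
    z_j.requires_grad_(True)
    loss = nt_xent(z_i, z_j, tau)
    loss.backward()

    # Pass 2: re-embed each micro-batch with autograd on and push the cached
    # embedding gradients through the encoder; parameter .grad values accumulate.
    for chunks, z in ((chunks_i, z_i), (chunks_j, z_j)):
        for c, g in zip(chunks, z.grad.split(micro_bs)):
            encoder(c).backward(g)
    return loss.detach()


# Compare against a single full-batch backward on a small deterministic encoder.
torch.manual_seed(0)
encoder = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
x_i, x_j = torch.randn(64, 32), torch.randn(64, 32)

encoder.zero_grad()
ref_loss = nt_xent(encoder(x_i), encoder(x_j))
ref_loss.backward()
ref_grads = [p.grad.clone() for p in encoder.parameters()]

encoder.zero_grad()
loss = cached_contrastive_step(encoder, x_i, x_j, micro_bs=16)
grads = [p.grad.clone() for p in encoder.parameters()]

print(loss.item(), ref_loss.item())
```

With a plain MLP like the one above, the accumulated gradients should match the full-batch ones up to floating-point rounding. With BatchNorm layers (typical of the ResNet encoders used for SimCLR), each micro-batch is normalized with its own statistics and the running statistics are updated in both passes, so the match would only be approximate.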