HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

loss computation: mean and not sum #135

Open · CassNot opened this issue 11 months ago

CassNot commented 11 months ago

Dear authors,

Thank you for your code!

We had a question about the loss implementation: we noticed that for each minibatch the mean is computed, not the sum as in the paper (equation 2 of https://arxiv.org/pdf/2004.11362.pdf): https://github.com/HobbitLong/SupContrast/blob/331aab5921c1def3395918c05b214320bef54815/losses.py#L96

We were wondering if there was a reason for this choice.

Thank you

HobbitLong commented 10 months ago

Good catch! I think eq. 2 in the paper omits the 1/(2N) factor.
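
For concreteness: the implementation reduces with a mean over all anchors, while eq. 2 as written sums over the 2N anchors (two views of N images), so the two differ exactly by that 1/(2N) constant. Below is a minimal sketch of the relationship, assuming a flat per-anchor loss tensor as in the line linked above; the shapes and values are illustrative, not taken from the repo.

```python
import torch

# Hypothetical per-anchor losses: N = 4 images with 2 views each,
# so anchor_count * batch_size = 2N = 8 anchors in total.
anchor_count, batch_size = 2, 4
loss = torch.rand(anchor_count * batch_size)

# What the implementation computes: a mean over all 2N anchors.
mean_loss = loss.mean()

# What eq. 2 literally writes: a sum over the anchors.
sum_loss = loss.sum()

# The two differ only by the constant 1/(2N), so they share the same
# minimizer; the scale difference can be absorbed into the learning rate.
assert torch.allclose(mean_loss, sum_loss / (anchor_count * batch_size))
```

Taking the mean also has the practical benefit of keeping the loss scale independent of the batch size.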

dave4422 commented 8 months ago

Hi,

I've been reviewing the implementation, and I noticed the line `loss = loss.view(anchor_count, batch_size).mean()`. Since `view()` does not change the elements, the result seems equivalent to simply using `loss.mean()`. Could you explain the rationale behind the reshaping here?

dave4422 commented 8 months ago

I assume it's just for readability?

HobbitLong commented 8 months ago

Yeah, it's just there to make the shape explicit, which may help readers understand what's going on.
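
To make the equivalence concrete, a quick sanity check (the shapes follow the quoted line; the values are arbitrary):

```python
import torch

anchor_count, batch_size = 2, 4
loss = torch.rand(anchor_count * batch_size)

# view() only reinterprets the layout; the elements, and hence their mean,
# are identical, so the reshape is purely for readability.
assert torch.allclose(
    loss.view(anchor_count, batch_size).mean(),
    loss.mean(),
)
```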