HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

about loss function #102

Open tommying opened 2 years ago

tommying commented 2 years ago

Hi, I appreciate your great work! But I'm a bit confused about the loss function.

Suppose a minibatch contains A, A1, A2, A3, B, C, D, E, where A and the Ai belong to one class, and B, C, D, E belong to four other classes.

The paper says that eq. 2 lets more positive pairs contribute to the loss, so that positive pairs (such as A•A1, A•A2, A•A3) become larger and negative pairs (such as A•B, A•C, A•D) become smaller.

Eq. 2 shows that the numerator inside the log still contains only one positive pair, while the denominator sums over every positive and negative pair.

If A is the anchor, A1, A2, A3 are the positives and B, C, D, E are the negatives. Writing S = A•A1 + A•A2 + A•A3 + A•B + A•C + A•D + A•E for the shared denominator (where A•X is shorthand for exp(z_A · z_X / τ)), the anchor-A part of eq. 2 is:

-( log(A•A1 / S) + log(A•A2 / S) + log(A•A3 / S) ) / 3
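As a quick check of the expression above, the anchor-A term can be evaluated numerically. This is a minimal pure-Python sketch, not the repository's `SupConLoss`: the similarity values and the temperature `tau` are made up, and `anchor_loss_out` is a hypothetical helper name. All three log terms share one denominator, which is computed once with log-sum-exp.

```python
import math

def anchor_loss_out(sims, positives, tau):
    """Anchor term of eq. 2 (L_out): the mean over positives p of
    -log(exp(s_p / tau) / sum_a exp(s_a / tau)), where a ranges over
    every sample except the anchor itself."""
    logits = {k: v / tau for k, v in sims.items()}
    # Shared log-denominator, computed with the log-sum-exp trick for stability.
    m = max(logits.values())
    log_denom = m + math.log(sum(math.exp(x - m) for x in logits.values()))
    return -sum(logits[p] - log_denom for p in positives) / len(positives)

# Hypothetical similarities z_A . z_X between anchor A and the other samples.
sims = {"A1": 0.9, "A2": 0.5, "A3": 0.2,
        "B": 0.1, "C": -0.3, "D": 0.0, "E": -0.5}
positives = ["A1", "A2", "A3"]

loss = anchor_loss_out(sims, positives, tau=0.1)
print(f"anchor-A loss: {loss:.4f}")
```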

Then, as the loss goes down, the numerator (such as A•A1) inside the log gets bigger and the denominator gets smaller. Is that correct? It seems that every element in the denominator except A•A1 will get smaller. Does that mean A•A2 and A•A3 in the denominator will get smaller? Logically, the values of A•A2 and A•A3 should get larger.
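One way to test the worry in the paragraph above is to nudge each similarity and see which way the anchor-A loss moves. The sketch below estimates dL/ds for every similarity s = z_A · z_X by central finite differences; the similarity values, `tau = 1.0`, and the helper names are illustrative assumptions. A negative derivative means gradient descent increases that similarity.

```python
import math

def anchor_loss_out(sims, positives, tau):
    """Anchor term of eq. 2 (L_out): mean over positives of -log softmax(s / tau)[p]."""
    logits = {k: v / tau for k, v in sims.items()}
    m = max(logits.values())
    log_denom = m + math.log(sum(math.exp(x - m) for x in logits.values()))
    return -sum(logits[p] - log_denom for p in positives) / len(positives)

def numeric_grad(sims, positives, tau, h=1e-6):
    """Central finite-difference estimate of dL/ds for each similarity."""
    grads = {}
    for k in sims:
        up, down = dict(sims), dict(sims)
        up[k] += h
        down[k] -= h
        grads[k] = (anchor_loss_out(up, positives, tau)
                    - anchor_loss_out(down, positives, tau)) / (2 * h)
    return grads

# Hypothetical similarities between anchor A and the other samples.
sims = {"A1": 0.9, "A2": 0.5, "A3": 0.2,
        "B": 0.1, "C": -0.3, "D": 0.0, "E": -0.5}
grads = numeric_grad(sims, ["A1", "A2", "A3"], tau=1.0)
for k, g in grads.items():
    # Gradient descent moves each similarity against the sign of its derivative.
    print(f"{k}: dL/ds = {g:+.4f} ({'pushed up' if g < 0 else 'pushed down'})")
```

With these numbers A•A2 and A•A3 both get a negative derivative: the pull from their own numerators (terms 2 and 3) outweighs the push from appearing in every denominator, while all four negatives get a positive derivative.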

I'm confused about this. Looking forward to your replies!

LQY404 commented 1 year ago

Hi, have you figured this out? I'm also confused about it. I looked at other implementations of this code and found the same problem...

HobbitLong commented 1 year ago

Would the comment here (https://github.com/HobbitLong/SupContrast/issues/64#issuecomment-1182845137) help answer the question?

In the example above, there are also terms 2 and 3, which encourage larger A•A2 and A•A3 in their respective numerators.
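The point about terms 2 and 3 can also be written as a short gradient derivation. The notation is introduced here for this sketch: for anchor i, s_a = z_i · z_a / τ for every a in A(i) (all samples except the anchor), P(i) is the set of positives, and π_k is the softmax of the s_a over A(i).

```latex
\mathcal{L}_i = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \Big( s_p - \log \sum_{a \in A(i)} e^{s_a} \Big),
\qquad
\pi_k = \frac{e^{s_k}}{\sum_{a \in A(i)} e^{s_a}}

\frac{\partial \mathcal{L}_i}{\partial s_k}
  = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \big( \mathbb{1}[k = p] - \pi_k \big)
  = \pi_k - \frac{\mathbb{1}[k \in P(i)]}{|P(i)|}
```

For a negative k the derivative is π_k > 0, so gradient descent lowers that similarity. For a positive k (A•A2, say) it is π_k - 1/|P(i)|, which is negative, so the similarity is raised, as long as π_k < 1/|P(i)|. Appearing in the denominators therefore does not make A•A2 and A•A3 shrink; it only stops a single positive from taking more than its share of the softmax mass.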

nlgandnlu commented 1 year ago

I have the same question as you, and I think it would be better to remove A•A2 and A•A3 from the denominator. What do you think?
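The suggestion above (keep only the current positive in each denominator, so the A•A1 term no longer contains A•A2 or A•A3) can be compared against eq. 2 in a small sketch. This is the commenter's proposed variant, not the loss in the paper or in this repository; `anchor_loss`, its `exclude_other_positives` flag, `grad_wrt`, the similarity values and `tau = 0.1` are all hypothetical.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def anchor_loss(sims, positives, tau, exclude_other_positives=False):
    """Anchor term of eq. 2 (L_out) or, with exclude_other_positives=True,
    the proposed variant whose denominator for positive p holds only p
    and the negatives."""
    logits = {k: v / tau for k, v in sims.items()}
    negatives = [k for k in logits if k not in positives]
    total = 0.0
    for p in positives:
        keys = [p] + negatives if exclude_other_positives else list(logits)
        total += logits[p] - log_sum_exp([logits[k] for k in keys])
    return -total / len(positives)

def grad_wrt(key, sims, positives, tau, variant, h=1e-6):
    """Central finite-difference estimate of dL/ds_key."""
    up, down = dict(sims), dict(sims)
    up[key] += h
    down[key] -= h
    return (anchor_loss(up, positives, tau, variant)
            - anchor_loss(down, positives, tau, variant)) / (2 * h)

# Hypothetical similarities between anchor A and the other samples.
sims = {"A1": 0.9, "A2": 0.5, "A3": 0.2,
        "B": 0.1, "C": -0.3, "D": 0.0, "E": -0.5}
positives = ["A1", "A2", "A3"]

for variant in (False, True):
    name = "variant" if variant else "eq. 2"
    loss = anchor_loss(sims, positives, 0.1, variant)
    g = {p: grad_wrt(p, sims, positives, 0.1, variant) for p in positives}
    print(name, f"loss={loss:.4f}", {p: round(v, 4) for p, v in g.items()})
```

With these values, eq. 2 gives A•A1 a positive derivative (it is pushed down because it already dominates the softmax), while under the variant every positive has a negative derivative and is pushed up independently of the other positives.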