Closed: yangzhao1230 closed this issue 1 year ago
I have the same question.
Hi, you are correct that the SupConH loss is wrong in the open-source version, and thanks so much for pointing it out! This was a mistake I made when porting the development code over to the public release (which is also why there were two SupConHardLoss() functions). The correct line 32 should be exp_logits_sum = n_pos * torch.log(exp_logits.sum(1))
which is what was used to produce the figures in the paper. Here is a short derivation from the equation presented in the paper to the correct implementation (omitting the part about -1/temp, since the features are L2 normalized).
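To spell out why the n_pos factor appears: for each anchor, the loss averages -log(exp(z·z_p/τ) / Σ_a exp(z·z_a/τ)) over the n_pos positives, so the log-sum-exp denominator term is counted once per positive, i.e. n_pos times. A minimal, self-contained sketch of the corrected per-anchor loss (the function name and tensor shapes here are illustrative assumptions, not the repository's actual API):

```python
import torch
import torch.nn.functional as F

def supcon_hard_loss(anchor, positives, negatives, temp=0.1):
    """Sketch of the corrected SupCon-Hard loss for a single anchor.

    anchor:    (d,)        L2-normalized anchor embedding
    positives: (n_pos, d)  L2-normalized positive embeddings
    negatives: (n_neg, d)  L2-normalized hard-negative embeddings
    """
    n_pos = positives.shape[0]
    # Cosine similarities (dot products of unit vectors) scaled by temperature.
    candidates = torch.cat([positives, negatives], dim=0)  # (n_pos + n_neg, d)
    logits = candidates @ anchor / temp                    # (n_pos + n_neg,)
    exp_logits = torch.exp(logits)
    pos_logits_sum = logits[:n_pos].sum()
    # The corrected term: the log of the denominator is counted once per
    # positive, hence the n_pos factor (the fix reported for line 32).
    exp_logits_sum = n_pos * torch.log(exp_logits.sum())
    return -(pos_logits_sum - exp_logits_sum) / n_pos
```

This is algebraically the same as averaging -log(softmax probability) over the positives; dividing the n_pos-weighted log-sum-exp term back out by n_pos recovers the per-positive form in the paper.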
I just recently built a personal 4090 machine, and will try to release a SupConHard CLEAN model trained on the full split100 SwissProt data in a month or two.
In your paper, you mention a logarithmic term in the SupCon-Hard loss, which seems to be missing in the code: https://github.com/tttianhao/CLEAN/blob/a26877743d685b3aa279dbc5df8c82a7222f5f80/app/src/CLEAN/losses.py#L56. I would greatly appreciate it if you could provide clarification on this matter.