tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Apache License 2.0

Bug: SimCLRLoss has a mistake in normalizing za and zb #237

Closed: kechan closed this issue 2 years ago

kechan commented 2 years ago

Where:

https://github.com/tensorflow/similarity/blob/master/tensorflow_similarity/losses/simclr.py

In def call(...):

if self.use_hidden_norm:
    za = tf.math.l2_normalize(za)  # axis defaults to None
    zb = tf.math.l2_normalize(zb)

Should be:

if self.use_hidden_norm:
    za = tf.math.l2_normalize(za, axis=-1)  # normalize each embedding vector
    zb = tf.math.l2_normalize(zb, axis=-1)

Cause

Doc: https://www.tensorflow.org/api_docs/python/tf/math/l2_normalize

The default axis is None, so the normalization is computed over the batch dimension as well as the embedding dimension. We want to normalize each embedding vector individually (i.e. axis=-1).
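
A minimal standalone snippet (my own illustration, not code from the library) showing the difference:

import tensorflow as tf

z = tf.constant([[3.0, 4.0],
                 [6.0, 8.0]])  # a batch of two embeddings

# Default axis=None: one global norm over the batch *and* embedding dims,
# so no individual embedding ends up with unit length.
global_norm = tf.math.l2_normalize(z)
print(tf.norm(global_norm, axis=-1).numpy())  # ~[0.45, 0.89]

# axis=-1: each embedding is normalized independently, as SimCLR expects.
per_row = tf.math.l2_normalize(z, axis=-1)
print(tf.norm(per_row, axis=-1).numpy())  # [1., 1.]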

Consequence

SimCLR training fails, with the loss stuck at an unreasonably high level even with a single batch of data and ResNet50Sim. (This is consistent with the bug: normalizing over the whole batch shrinks every embedding's norm, so the similarity logits stay near zero, the softmax over candidates stays close to uniform, and the loss plateaus near its chance level.)

Related

See issue https://github.com/tensorflow/similarity/issues/233 for full details on the test and investigation behind this.

owenvallis commented 2 years ago

Looks like the same issue is in the Barlow Twins loss. I've changed them both and I'm rerunning the notebook. I'll push the patch once I verify the results are better, but so far it's looking good for SimCLR. Thanks again for finding this.

kechan commented 2 years ago

Glad to hear this helps with debugging the Barlow Twins loss.

Just throwing out another thought, unrelated to this issue: is the dataset shuffle buffer size of 1024 coming from the papers or their code? I would be surprised if all three papers used the same dataset setup. Since the batch size is 512, I thought the shuffle buffer could be a bit bigger.

I don't have deep experience with contrastive training, so this is speculative, but I have a hunch that higher-quality shuffling may lead to better training that produces better representations. I'm not sure if there's any research that puts emphasis on this (or shows it doesn't matter). I'll probably play around with it and see.
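
For concreteness, a sketch of what I mean (the placeholder dataset and the buffer multiplier are just assumptions, not from the repo's notebooks):

import tensorflow as tf

BATCH_SIZE = 512

# Placeholder dataset; in practice this would be the contrastive training set.
ds = tf.data.Dataset.from_tensor_slices(tf.range(50_000))

# A shuffle buffer several times the batch size (or the whole dataset, if it
# fits in memory) mixes examples more thoroughly than buffer_size=1024.
ds = ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)
ds = ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)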

owenvallis commented 2 years ago

Fixed in 98ecb5758e3bb46022387ff7f7852d87dc847f35 and 996c6f5a3448c47e310f38c8118ff6883c9269ee, and pushed to PyPI as 0.15.4.

Huge thanks again for catching this @kechan. I ran through the other contrastive losses and I think they should be fine. I'm also going to try to knock out some more test coverage once some of my other projects wrap up.
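
As a sketch of the kind of regression test that would have caught this (treat the import path, the default of use_hidden_norm, and the loss_fn(za, zb) call signature as assumptions based on the snippet above, not the actual API):

import tensorflow as tf
from tensorflow_similarity.losses import SimCLRLoss  # import path assumed

def test_simclr_loss_is_invariant_to_per_row_scaling():
    # With use_hidden_norm=True, the loss should normalize each embedding
    # along axis=-1, making it invariant to rescaling individual rows.
    # The buggy global normalization fails this check, since per-row scales
    # change the embeddings' relative magnitudes after a global normalize.
    za = tf.random.normal((8, 16))
    zb = tf.random.normal((8, 16))
    loss_fn = SimCLRLoss()  # use_hidden_norm assumed to default to True

    scales = tf.reshape(tf.range(1.0, 9.0), (8, 1))  # distinct scale per row
    base = loss_fn(za, zb)
    scaled = loss_fn(scales * za, scales * zb)

    tf.debugging.assert_near(base, scaled, atol=1e-5)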

kechan commented 2 years ago

@owenvallis Not a problem. Many thanks to you and your team for building this framework out, and for the fast turnaround in fixing important issues. It has already helped me bootstrap myself into these recent contrastive modeling techniques, which I had procrastinated on for a long while.