tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Apache License 2.0

Fix simclr loss to work with distributed training. #262

Closed yonigottesman closed 2 years ago

yonigottesman commented 2 years ago

Fixes #258. Also, while fixing this I noticed the current implementation is wrong even on a single GPU. The paper states: "The final loss is computed across all positive pairs, both (i, j) and (j, i)", meaning the loss is computed twice for each pair, once for (za, zb) and once for (zb, za).
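For reference, here is a minimal standalone sketch of that symmetric NT-Xent computation (this is not the package's implementation; the function name and the temperature default are just illustrative):

import tensorflow as tf

def nt_xent_symmetric(za, zb, temperature=0.5):
    # Sketch of the SimCLR loss with both (i, j) and (j, i) directions.
    batch_size = tf.shape(za)[0]
    za = tf.math.l2_normalize(za, axis=1)
    zb = tf.math.l2_normalize(zb, axis=1)

    # 2N x 2N similarity matrix over the concatenated views.
    z = tf.concat([za, zb], axis=0)
    sim = tf.matmul(z, z, transpose_b=True) / temperature

    # Remove self-similarity from the denominator.
    sim = sim - tf.eye(2 * batch_size) * 1e9

    # The positive for row i is its counterpart in the other view, so the
    # cross-entropy over all 2N rows covers both (za, zb) and (zb, za).
    labels = tf.concat(
        [tf.range(batch_size, 2 * batch_size), tf.range(batch_size)], axis=0
    )
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim)
    return tf.reduce_mean(losses)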

To check that this works, I ran the following code on a 4-GPU instance:

from tensorflow_similarity.losses.simclr import SimCLRLoss
import tensorflow as tf

global_batch_size = 512

# Two random "views" standing in for the projector outputs.
h1 = tf.random.normal((global_batch_size, 256))
h2 = tf.random.normal((global_batch_size, 256))

# Single-device reference: the default reduction should equal the sum of the
# per-example losses divided by the global batch size.
l = SimCLRLoss()(h1, h2)
assert l == tf.reduce_sum(SimCLRLoss(reduction=tf.keras.losses.Reduction.NONE)(h1, h2)) / global_batch_size

dataset = tf.data.Dataset.from_tensor_slices((h1, h2)).batch(global_batch_size)
bh1, bh2 = next(iter(dataset))
assert l == SimCLRLoss()(bh1, bh2)

# Distribute the same batch across the 4 GPUs and recompute the loss.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def replica_fn(inputs):
    h1, h2 = inputs
    # Reduction.NONE so each replica returns its per-example losses.
    return SimCLRLoss(reduction=tf.keras.losses.Reduction.NONE)(h1, h2)

x = next(iter(dist_dataset))
result = strategy.run(replica_fn, args=(x,))

# Sum over replicas and examples, then average over the global batch size.
l_distributed = tf.reduce_sum(strategy.reduce("SUM", result, axis=None)) / global_batch_size
abs(l - l_distributed)

>>> <tf.Tensor: shape=(), dtype=float32, numpy=9.536743e-07>

As you can see, the result is not bit-exact (the difference is on the order of 1e-7), but it matches the result of the original implementation.

I'm not sure how to test this behavior, because the tests probably don't run on a multi-GPU instance. Maybe the tf.distribute strategy context can be mocked somehow just for testing?
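One option (just a sketch, assuming the test machine only has a CPU) is to split the single physical CPU into logical devices before TensorFlow initializes them, and build the MirroredStrategy over those virtual devices so the cross-replica reduction gets exercised without real GPUs:

import tensorflow as tf

# Must run before any op initializes the devices (e.g. in a test fixture).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2
)

# Two virtual CPUs behave like a 2-replica MirroredStrategy.
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
assert strategy.num_replicas_in_sync == 2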

google-cla[bot] commented 2 years ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

yonigottesman commented 2 years ago

@owenvallis thanks for clarifying (in the issue thread); indeed the forward pass does compute both (za, zb) and (zb, za). Personally I think it's better to have the whole loss computation in a single place: it's a bit less confusing, it looks exactly like the loss in the paper, and it's easier (for me :) ) to reason about the averaging over the global batch size across GPUs. It's up to you: if you prefer to call it in the forward pass I'll update the loss, and if you prefer to call it through the loss I'll update the forward pass. I'm still not sure why the scaling and margin are needed; the original implementation doesn't do that. Did you see better results using it?
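For what it's worth, keeping everything in the loss would let the replica step stay close to the standard tf.distribute pattern. A rough sketch, assuming loss_fn is a SimCLRLoss built with Reduction.NONE:

import tensorflow as tf

def replica_step(loss_fn, za, zb, global_batch_size):
    # One loss value per example on this replica.
    per_example_loss = loss_fn(za, zb)
    # Dividing by the *global* batch size makes the summed cross-replica
    # result match the single-device loss.
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=global_batch_size
    )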

owenvallis commented 2 years ago

@yonigottesman you raise a good point. The motivation for handling this in the forward pass was to provide support for the different contrastive losses. SimSiam requires 4 inputs, both sets of projector and predictor outputs, while Barlow Twins only requires a single call using za and zb.

I suppose we could have SimCLR and Barlow Twins accept za and zb, and then use a conditional check for SimSiam to handle the 4-input case. I'll take a look at it this weekend.
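Something along these lines, maybe? Purely illustrative; the names below (compute_contrastive_loss, pa/pb for the predictor outputs, the 4-argument SimSiamLoss call) are placeholders, not the actual ContrastiveModel internals:

def compute_contrastive_loss(loss_fn, za, zb, pa=None, pb=None):
    # SimSiam needs both projector (z) and predictor (p) outputs.
    if type(loss_fn).__name__ == "SimSiamLoss":
        return loss_fn(za, zb, pa, pb)
    # SimCLR and Barlow Twins only need the two projector outputs.
    return loss_fn(za, zb)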