keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Gradients Sum in Keras 3 Optimizer with Tensorflow backend #19308

Open kkolli opened 7 months ago

kkolli commented 7 months ago

Hello! While migrating from Keras 2 to Keras 3.0.1 with the TF backend, I noticed that the optimizer no longer has a skip_gradients_aggregation option. In Keras 2 this option defaulted to False, which summed the gradients, so losing it caused a regression in our model. We mostly use the legacy optimizer.

If I do the following in train_step, I'm able to reach model parity in graph mode:

    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            loss_dict = self.compute_loss_dict(inputs[0], inputs[1])
            loss = loss_dict['total_loss']

        trainable_weights = self.get_trainable_weights()
        gradients = tape.gradient(loss, trainable_weights)
        gradients = tf.distribute.get_replica_context().all_reduce('sum', gradients)

        self.optimizer.apply_gradients(zip(gradients, trainable_weights))

        return {'loss': loss}
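The all_reduce('sum', ...) call above adds the per-replica gradients together. Whether that sum equals the gradient of the mean loss over the whole global batch depends on how each replica's loss is scaled. The plain-Python sketch below assumes a hypothetical one-weight linear model with a mean-squared-error loss, equally sized shards, and per-replica losses divided by the number of replicas (equivalent to dividing the summed per-example loss by the global batch size, as tf.nn.compute_average_loss does when given global_batch_size). Under those assumptions the summed scaled gradients reproduce the global-batch gradient, while summing unscaled gradients overshoots it by the replica count. The helpers are illustrative stand-ins, not TensorFlow or Keras APIs.

```python
from typing import List, Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) w.r.t. w for a hypothetical toy model."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def split(seq: Sequence[float], num_replicas: int) -> List[Sequence[float]]:
    """Split a global batch into equally sized per-replica shards."""
    shard = len(seq) // num_replicas
    return [seq[i * shard:(i + 1) * shard] for i in range(num_replicas)]


def all_reduce_sum(values: Sequence[float]) -> float:
    """Stand-in for replica_context.all_reduce('sum', ...) on one scalar gradient."""
    return sum(values)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
num_replicas = 4

global_grad = mse_grad(w, xs, ys)

# Each replica differentiates the mean loss over its own shard.
per_replica = [
    mse_grad(w, x, y)
    for x, y in zip(split(xs, num_replicas), split(ys, num_replicas))
]

# Dividing each replica's loss (hence gradient) by num_replicas before summing
# recovers the global-batch gradient; summing the raw gradients does not.
scaled_sum = all_reduce_sum([g / num_replicas for g in per_replica])
unscaled_sum = all_reduce_sum(per_replica)

print(global_grad, scaled_sum, unscaled_sum)  # -76.5 -76.5 -306.0
```

In a real train_step the 1 / num_replicas factor is applied to the loss before tape.gradient; the result is the same because the gradient is linear in the loss.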

Therefore, I have two questions:

  1. I looked at both the source code and the documentation. Is there a way to avoid writing gradients = tf.distribute.get_replica_context().all_reduce('sum', gradients), as with the previous iterations of the optimizers in tf.keras?

  2. If there is no alternative, will this be added sometime in the future?

SuryanarayanaY commented 7 months ago

Hi @kkolli ,

If you want to retain the Keras 2 behaviour with Keras 3, you can install the tf_keras package and set the environment variable os.environ["TF_USE_LEGACY_KERAS"]="1".
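A minimal sketch of the ordering involved, assuming the tf_keras package (published on PyPI as tf-keras) is installed. The helper name enable_legacy_keras is hypothetical, and it conservatively refuses to run once TensorFlow has been imported, on the assumption that setting the variable before the first TensorFlow import is the safe ordering.

```python
import os
import sys


def enable_legacy_keras() -> None:
    """Hypothetical helper: make tf.keras resolve to tf_keras (Keras 2).

    Conservative assumption: the variable must be set before TensorFlow
    is first imported, so refuse to continue if it already was.
    """
    if "tensorflow" in sys.modules:
        raise RuntimeError("Set TF_USE_LEGACY_KERAS before importing tensorflow.")
    os.environ["TF_USE_LEGACY_KERAS"] = "1"


enable_legacy_keras()
# import tensorflow as tf  # import only after the variable is set
```

Note that the variable affects every package in the same Python process, so it is an all-or-nothing switch rather than a per-model option.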

> 1. is there a way I can avoid writing this code: gradients = tf.distribute.get_replica_context().all_reduce('sum', gradients) as the previous iterations of the optimizers in tf.keras?

I need more context here: are you using a single worker or multiple workers? A code snippet would also help.

Thanks!

kkolli commented 7 months ago

Ah -

> os.environ["TF_USE_LEGACY_KERAS"]="1"

Part of our motivation for the migration was to be able to try other backends, such as JAX, easily. So, as an incremental step, our goal has been to remove tf.keras completely.

> I need more context here: are you using a single worker or multiple workers? A code snippet would also help.

Sure, here are some clarifications:

  1. I want to sum the gradients without worrying about tf.distribute in my train step. I'm wondering if this is possible through Keras 3 itself as it was in Keras 2.

  2. We run the model in graph mode on 1 worker with 1 GPU, 1 worker with 8 GPUs, and multiple workers with 8 GPUs. So it is both a multi-worker and a multi-GPU setup. Below is a simplified version of how we sum the gradients after the migration:

    import tensorflow as tf
    import keras
    from keras import layers

    tf.config.run_functions_eagerly(False)
    strategy = tf.distribute.MirroredStrategy()  # Will change this to MultiWorkerMirroredStrategy when multi-worker

    class CustomModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.flatten = layers.Flatten()
            self.dense1 = layers.Dense(128, activation='relu')
            self.dense2 = layers.Dense(10, activation='softmax')

        def call(self, inputs, training=False):
            x = self.flatten(inputs)
            x = self.dense1(x)
            return self.dense2(x)

        def train_step(self, data):
            x, y = data

            with tf.GradientTape() as tape:
                predictions = self(x, training=True)
                # Keras 3 replacement for the deprecated self.compiled_loss(y, predictions).
                loss = self.compute_loss(y=y, y_pred=predictions)

            gradients = tape.gradient(loss, self.trainable_variables)
            # Explicitly sum the gradients across replicas before applying them.
            gradients = tf.distribute.get_replica_context().all_reduce('sum', gradients)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

            return {'loss': loss}

    with strategy.scope():
        custom_model = CustomModel()

        custom_model.compile(optimizer='adam',
                             loss='sparse_categorical_crossentropy',
                             jit_compile=False)

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train / 255.0
    x_train = x_train.reshape((-1, 28, 28))

    custom_model.fit(x_train, y_train, epochs=5, steps_per_epoch=100)

kkolli commented 7 months ago

Hi @SuryanarayanaY, as part of this, could you confirm that summing or averaging the gradients is expected to be done by the training loop, and not by the optimizer, going forward?

kkolli commented 7 months ago

One additional point: in my tests, this fails in graph mode when running on 2 workers with 4 GPUs (it works in eager mode). In graph mode, it only seems to work with single-worker (MirroredStrategy) training.

Additionally, the following code, which uses gradient_accumulation_steps on 1 worker, also fails. Tested on TF 2.16.1 and Keras 3.1 with 4 GPUs:

    import tensorflow as tf
    import keras
    from keras import layers

    tf.config.run_functions_eagerly(False)
    strategy = tf.distribute.MirroredStrategy()  # Will change this to MultiWorkerMirroredStrategy when multi-worker

    class CustomModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.flatten = layers.Flatten()
            self.dense1 = layers.Dense(128, activation='relu')
            self.dense2 = layers.Dense(10, activation='softmax')

        def call(self, inputs, training=False):
            x = self.flatten(inputs)
            x = self.dense1(x)
            return self.dense2(x)

        def train_step(self, data):
            x, y = data

            with tf.GradientTape() as tape:
                predictions = self(x, training=True)
                # Keras 3 replacement for the deprecated self.compiled_loss(y, predictions).
                loss = self.compute_loss(y=y, y_pred=predictions)

            gradients = tape.gradient(loss, self.trainable_variables)
            # Manual cross-replica sum disabled in this variant.
            # gradients = tf.distribute.get_replica_context().all_reduce('sum', gradients)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

            return {'loss': loss}

    with strategy.scope():
        custom_model = CustomModel()
        # Update the weights only every 2 steps, using the accumulated gradients.
        optimizer = keras.optimizers.Adam(learning_rate=0.001, gradient_accumulation_steps=2)
        custom_model.compile(optimizer=optimizer,
                             loss='sparse_categorical_crossentropy',
                             jit_compile=False)

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train / 255.0
    x_train = x_train.reshape((-1, 28, 28))

    custom_model.fit(x_train, y_train, epochs=5, steps_per_epoch=100)

Error:

    pypi__keras_3_1_0/keras/src/backend/tensorflow/optimizer.py", line 117, in _backend_update_step
        tf.__internal__.distribute.interim.maybe_merge_call(

    RuntimeError: Exception encountered when calling Cond.call().

    `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.

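The Keras 3 optimizer documentation describes gradient_accumulation_steps as updating the variables only every N steps, using the average of the gradients accumulated since the last update. The plain-Python sketch below illustrates that idea with a hypothetical AccumulatingSGD class (not Keras code). It also makes visible the conditional "apply only on every N-th step": the traceback above points at Cond.call(), which suggests the Keras update runs inside such a conditional, the pattern the merge_call message warns about when the branch contains a cross-replica synchronization point. That reading of the traceback is an inference, not something confirmed in this thread.

```python
from typing import List


class AccumulatingSGD:
    """Hypothetical sketch of gradient accumulation (plain SGD, not Keras)."""

    def __init__(self, learning_rate: float, accumulation_steps: int, num_weights: int):
        self.learning_rate = learning_rate
        self.accumulation_steps = accumulation_steps
        self.accumulators = [0.0] * num_weights
        self.iterations = 0

    def apply_gradients(self, grads: List[float], weights: List[float]) -> None:
        self.iterations += 1
        for i, g in enumerate(grads):
            self.accumulators[i] += g
        # Data-dependent branch: in a traced graph this becomes a cond whose body
        # would contain the cross-replica synchronization (the update).
        if self.iterations % self.accumulation_steps == 0:
            for i in range(len(weights)):
                average = self.accumulators[i] / self.accumulation_steps
                weights[i] -= self.learning_rate * average
                self.accumulators[i] = 0.0


weights = [1.0]
opt = AccumulatingSGD(learning_rate=0.5, accumulation_steps=2, num_weights=1)
opt.apply_gradients([1.0], weights)  # step 1: only accumulated
print(weights)  # [1.0]
opt.apply_gradients([3.0], weights)  # step 2: applies the average 2.0 -> 1.0 - 0.5 * 2.0
print(weights)  # [0.0]
```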
fjmscut commented 7 months ago

I found that in TF 2.16.1 and Keras 3 there is no code that all-reduces gradients across replicas, neither in the model.fit API nor in the optimizer's apply_gradients API. This may mean that the gradients differ across replicas, which would be a very important bug (maybe). Can somebody confirm this?

fchollet commented 7 months ago

@kkolli @fjmscut You're right, gradient sum reduction wasn't enabled by default with a tf.distribute.MirroredStrategy. This was a regression. We have switched back the default behavior to sum-reduction.

Please confirm that the fix at HEAD is working for your use case.

@kkolli to be clear, you do not need the skip_aggregation argument, you only need the default behavior to be sum-reduction? We did not reintroduce the argument for now, but we could. It would only apply to TensorFlow if we did it.

kkolli commented 7 months ago

> to be clear, you do not need the skip_aggregation argument, you only need the default behavior to be sum-reduction? We did not reintroduce the argument for now, but we could. It would only apply to TensorFlow if we did it.

@fchollet, yes: for my use case, the skip_aggregation argument is not needed as long as the default behavior is reverted to sum. Thank you!

> We did not reintroduce the argument for now, but we could. It would only apply to TensorFlow if we did it

I do think it might be beneficial to bring this argument back for TF, though, since some models might not need the aggregation. For example, models using Horovod already perform a SUM or AVG before the TF optimizer runs.
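Why a skip option matters when gradients were already reduced elsewhere can be shown with a plain-Python sketch: if every replica already holds the summed gradient and the optimizer sums across replicas again, the result is multiplied by the replica count. all_reduce_sum and optimizer_reduce are hypothetical stand-ins (the flag only mirrors the idea behind skip_gradients_aggregation); the sketch models neither Horovod nor Keras internals.

```python
from typing import List


def all_reduce_sum(per_replica: List[float]) -> List[float]:
    """Stand-in for a sum all-reduce: every replica ends up holding the total."""
    total = sum(per_replica)
    return [total] * len(per_replica)


def optimizer_reduce(grads: List[float], skip_aggregation: bool) -> List[float]:
    """Hypothetical optimizer-side reduction with an opt-out flag."""
    return grads if skip_aggregation else all_reduce_sum(grads)


local_grads = [0.25, 0.5, 0.75, 1.5]  # assumed: one scalar gradient per replica

# An external library (e.g. Horovod) has already summed the gradients.
pre_reduced = all_reduce_sum(local_grads)

single_reduced = optimizer_reduce(pre_reduced, skip_aggregation=True)
double_reduced = optimizer_reduce(pre_reduced, skip_aggregation=False)

print(single_reduced[0], double_reduced[0])  # 3.0 12.0
```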

fjmscut commented 7 months ago

Thanks, but it still behaves a bit differently from Keras 2. In my use case (mixed precision training with multi-worker training), tape.gradient returns fp32 gradients. For best performance, I cast the gradients to fp16 --> all-reduce --> cast back to fp32 --> unscale by the loss scale --> apply gradients (with skip_gradients_aggregation=True). So I need fine-grained control over the gradients. Separately, could the Keras team add a tutorial on combining a custom training loop, mixed precision training, and multi-worker training? I think this use case is in high demand; the current tutorials each cover one technique separately, and there are still some bugs when all of these techniques are used together.
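The pipeline described above (cast to fp16, all-reduce, cast back to fp32, unscale) can be sketched in plain Python. struct's "e" and "f" formats round a value to IEEE half and single precision; the sketch assumes the per-replica gradients still carry the loss scale and emulates an all-reduce that accumulates in fp16. It illustrates the arithmetic only and is not the poster's implementation or a Keras API. One difference from a real cast: struct raises OverflowError for values beyond the fp16 range, where a tensor cast would produce inf.

```python
import struct
from typing import List


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def fp16_all_reduce_then_unscale(
    per_replica_grads: List[List[float]], loss_scale: float
) -> List[float]:
    """Cast scaled grads to fp16, sum across replicas in fp16, cast to fp32, unscale."""
    num_vars = len(per_replica_grads[0])
    out = []
    for i in range(num_vars):
        acc = 0.0
        for grads in per_replica_grads:
            # Every partial sum is rounded to fp16, as an fp16 all-reduce would do.
            acc = to_fp16(acc + to_fp16(grads[i]))
        # Back to fp32 first, then divide out the loss scale.
        out.append(to_fp32(to_fp32(acc) / loss_scale))
    return out


loss_scale = 1024.0
# Two replicas, two variables; gradients already multiplied by loss_scale.
scaled = [[1024.0, 2.0], [2048.0, 6.0]]
reduced = fp16_all_reduce_then_unscale(scaled, loss_scale)
print(reduced)  # [3.0, 0.0078125]
```

Doing the fp16 cast while the gradients are still scaled is what keeps small values representable during the reduction.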

fchollet commented 6 months ago

@fjmscut for your use case, would it be sufficient to add back the argument skip_aggregation?

fjmscut commented 6 months ago

Yes, I need this argument for my use case. By the way, I would really appreciate it if the Keras team could add a tutorial on combining a custom training loop, mixed precision training, multi-worker training, and XLA. Thanks!

fchollet commented 6 months ago

> multi worker training

You mean multi-GPU rather than multi-worker, right? We currently have no support for actual Multi-Worker PSS.

> custom training loop & mixed precision training

Is this a custom train_step? Supporting mixed precision takes just one line: make sure to call loss = optimizer.scale_loss(loss) on your loss after computing it. Everything else is built-in.
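Loss scaling works because the gradient is linear in the loss: multiplying the loss by a factor S multiplies every gradient by S, which keeps small gradients representable in float16, and dividing by S before the update restores them. The plain-Python sketch below shows that arithmetic together with a typical dynamic-scaling rule (on non-finite gradients, skip the step and halve the scale; after a run of finite steps, double it). DynamicLossScale is a hypothetical class, not the Keras implementation, and its default constants are assumptions chosen for illustration.

```python
import math
from typing import List, Optional


class DynamicLossScale:
    """Hypothetical sketch of dynamic loss scaling (not Keras code)."""

    def __init__(self, initial_scale: float = 2.0 ** 15, growth_steps: int = 2000):
        # Both defaults are illustrative assumptions.
        self.scale = initial_scale
        self.growth_steps = growth_steps
        self.good_steps = 0

    def scale_loss(self, loss: float) -> float:
        return loss * self.scale

    def unscale(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Return unscaled grads, or None if this step should be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale /= 2.0  # overflow: shrink the scale and skip the step
            self.good_steps = 0
            return None
        unscaled = [g / self.scale for g in scaled_grads]  # use the current scale
        self.good_steps += 1
        if self.good_steps >= self.growth_steps:
            self.scale *= 2.0  # long run of finite grads: try a larger scale
            self.good_steps = 0
        return unscaled


ls = DynamicLossScale(initial_scale=8.0, growth_steps=2)
w = 0.25
scaled_loss = ls.scale_loss(w ** 2)      # 0.0625 * 8.0 = 0.5
scaled_grad = 2.0 * w * ls.scale         # d(scaled_loss)/dw
first = ls.unscale([scaled_grad])        # [0.5], the true d(w ** 2)/dw
skipped = ls.unscale([float("inf")])     # None; scale drops to 4.0
ls.unscale([4.0])
ls.unscale([4.0])                        # second finite step: scale grows to 8.0
```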

> XLA

In a custom train_step you don't need to do anything specific to support XLA compilation or tf.distribute, it's all built-in.

> then for best performance, i cast the gradient to fp16 --> all reduce --> cast back to fp32 --> unscale loss_scale --> apply gradient (with skip_gradients_aggregation=True args). So i need finegrain control of the gradient to satisfy my use case

If what you need is to cast gradients to float16 before reduction, then what I suggest is to write an Optimizer subclass that does this. Here's the part of the code that's relevant:

    def _distributed_tf_update_step(
        self, distribution, grads_and_vars, learning_rate
    ):
        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)

        def apply_grad_to_update_var(var, grad, learning_rate):
            return self.update_step(grad, var, learning_rate)

        for grad, var in grads_and_vars:
            distribution.extended.update(
                var,
                apply_grad_to_update_var,
                args=(grad, learning_rate),
                group=False,
            )

Just cast the grads before _all_reduce_sum_gradients and cast them back after. No need for a custom training loop here.

fjmscut commented 6 months ago

Thanks, I will give it a try. However, I found that gradient unscaling is applied very early, in the apply_gradients API, which produces unscaled fp32 gradients. If I subclass the optimizer as you suggest, these unscaled fp32 gradients will be cast to fp16, which may cause the gradients to underflow (just a guess).
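The underflow concern can be checked numerically with plain Python: struct's "e" format rounds to IEEE half precision, whose smallest positive subnormal is about 6e-8. With an assumed gradient of 1e-8 and an assumed loss scale of 2**15, casting after unscaling rounds the gradient to zero, while casting while still scaled keeps it recoverable after unscaling. This only illustrates the numerics (Python floats are fp64 here) and does not reproduce what Keras does internally.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8              # assumed tiny gradient
loss_scale = 2.0 ** 15   # assumed loss scale

unscaled_then_cast = to_fp16(grad)                          # cast after unscaling
scaled_then_cast = to_fp16(grad * loss_scale) / loss_scale  # cast while still scaled

print(unscaled_then_cast)                           # 0.0
print(abs(scaled_then_cast - grad) / grad < 1e-3)   # True
```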