cc. @andreped https://github.com/andreped/GradientAccumulator/issues/5
Hello, @innat!
I have not had the time to add multi-GPU support to GradientAccumulator, but can make an attempt at it today.
However, batch training + gradient accumulation + mixed precision works seamlessly.
I have been using it for various projects already.
Thanks for your response. I'd like to extend the above gist (custom fit + overriding the train step function) to multi-GPU (and hopefully TPU). Also, I've updated my query on Stack Overflow with a bounty (HERE).
I noticed that @stefan-falk also faced a similar error (https://github.com/tensorflow/tensorflow/issues/50454) to the one I reported above. He tried many approaches, HERE; it may give some insight.
Regarding mixed precision, as I said, I was wondering if we need to call opt.get_scaled_loss and opt.get_unscaled_gradients. The official docs say to do so only in custom training loops.
cc. @MrForExample
Regarding mixed precision, as I said, I was wondering if we need to call opt.get_scaled_loss and opt.get_unscaled_gradients. The official docs say to do so only in custom training loops.
Hmm, that's interesting. However, can't it be argued that overriding train_step actually introduces a custom training loop? I mean, that method could do anything at this point. But I agree, it is not so clear from the documentation. If anyone wishes to read further, see here.
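For reference, a minimal sketch of what calling these helpers from an overridden train_step looks like (my own illustration following the mixed precision guide; the model and data here are made up, and compile() wraps the optimizer in a LossScaleOptimizer automatically because of the global policy):

import numpy as np
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

class MPModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
            # Scale the loss so small float16 gradients do not underflow.
            scaled_loss = self.optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, self.trainable_variables)
        # Unscale before applying; apply_gradients also updates the loss
        # scale and skips the step if any gradient is non-finite.
        grads = self.optimizer.get_unscaled_gradients(scaled_grads)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1, dtype="float32")(inputs)
model = MPModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((64, 32)), np.random.random((64, 1)), epochs=1)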
Will start on the multi-GPU support now. Did you have a gist I could use for debugging/testing, @innat? Also note that the GradientAccumulator (without multi-GPU) also works with TPUs. But I am only able to run tests locally, as I doubt I am allowed to use multi-GPUs in a single colab session.
Here is a gist (also mentioned above). For multi-GPU (and TPU), you can use the Kaggle environment; it now provides multi-GPU instances.
As mentioned in the other ticket, Graphcore had a design for an optimizer wrapper that includes cross-replica handling:
/Cc @georgepaw
As the error suggests, aggregating gradients inside a nested tf.function is not yet supported.
RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.
Hence, I tried the code in eager mode by setting model.compile(run_eagerly=True), and it works fine there. Please refer to the attached log below. First, the code runs with model.compile(run_eagerly=True), which executes fine; then the same code is tested with model.compile(run_eagerly=False), and execution terminates immediately with a runtime error.
This testing was done on a 2-GPU machine.
(tf) suryanarayanay@ubuntu-20-04-test-gpu-surya:~$ python 17429_grad_accumulation_on_multi_gpu_r1.py
2023-01-20 05:39:37.545128: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-20 05:39:38.049160: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-01-20 05:39:40.134530: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/suryanarayanay/miniconda3/envs/tf/lib/
2023-01-20 05:39:40.134656: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/suryanarayanay/miniconda3/envs/tf/lib/
2023-01-20 05:39:40.134680: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-01-20 05:39:44.877039: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38285 MB memory: -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0
2023-01-20 05:39:44.880513: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 38397 MB memory: -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:05.0, compute capability: 8.0
run_eagerly=True in model.compile()
Epoch 1/3
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
1/10000 [..............................] - ETA: 7:36:28 - loss: 2.3965 - accuracy: 0.0000e+00WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
3/10000 [..............................] - ETA: 7:06 - loss: 2.2653 - accuracy: 0.1111 WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
9/10000 [..............................] - ETA: 7:01 - loss: 2.3390 - accuracy: 0.14812023-01-20 05:39:51.483722: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:115] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-01-20 05:39:51.485528: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.0
2023-01-20 05:39:51.485549: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:237] Used ptxas at ptxas
2023-01-20 05:39:51.694080: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:115] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-01-20 05:39:52.020570: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:115] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7f6a1c0b3160> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7f6a1c0b3160> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
2023-01-20 05:39:52.275011: W tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:115] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
48/10000 [..............................] - ETA: 13:48 - loss: 2.1689 - accuracy: 0.2639WARNING:tensorflow:5 out of the last 5 calls to <function _apply_all_reduce.<locals>._all_reduce at 0x7f6a145b11f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
58/10000 [..............................] - ETA: 12:41 - loss: 2.1184 - accuracy: 0.3190WARNING:tensorflow:6 out of the last 6 calls to <function _apply_all_reduce.<locals>._all_reduce at 0x7f6a145101f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
10000/10000 [==============================] - 459s 46ms/step - loss: 0.2902 - accuracy: 0.9181
Epoch 2/3
10000/10000 [==============================] - 456s 46ms/step - loss: 0.1310 - accuracy: 0.9618
Epoch 3/3
10000/10000 [==============================] - 457s 46ms/step - loss: 0.0934 - accuracy: 0.9728
run_eagerly=False in model.compile()
Epoch 1/2
Traceback (most recent call last):
File "/home/suryanarayanay/17429_grad_accumulation_on_multi_gpu_r1.py", line 94, in <module>
custom_model.fit(x_train, y_train, batch_size=6, epochs=2, verbose = 1)
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file6ty2k0r9.py", line 15, in tf__train_function
retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
File "/home/suryanarayanay/17429_grad_accumulation_on_multi_gpu_r1.py", line 42, in train_step
tf.cond(
File "/home/suryanarayanay/17429_grad_accumulation_on_multi_gpu_r1.py", line 49, in apply_accu_gradients
self.optimizer.apply_gradients(zip(self.gradient_accumulation, self.trainable_variables))
RuntimeError: in user code:
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1249, in train_function *
return step_function(self, iterator)
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1233, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py", line 1222, in run_step **
outputs = model.train_step(data)
File "/home/suryanarayanay/17429_grad_accumulation_on_multi_gpu_r1.py", line 42, in train_step
tf.cond(
File "/home/suryanarayanay/17429_grad_accumulation_on_multi_gpu_r1.py", line 49, in apply_accu_gradients
self.optimizer.apply_gradients(zip(self.gradient_accumulation, self.trainable_variables))
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/mixed_precision/loss_scale_optimizer.py", line 1301, in apply_gradients
grads_and_vars = self._optimizer.aggregate_gradients(grads_and_vars)
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1105, in aggregate_gradients
return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
File "/home/suryanarayanay/miniconda3/envs/tf/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/utils.py", line 42, in all_reduce_sum_gradients
reduced = tf.distribute.get_replica_context().merge_call(
RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.
(tf) suryanarayanay@ubuntu-20-04-test-gpu-surya:~$
@innat Is eager mode OK for you? It has a performance issue, but the code seems to work fine here.
@SuryanarayanaY Thanks for the test. Please note, eager mode is a nice option for testing the code under some circumstances. But it should not be treated as a solution, as it brings a large performance cost (I raised a ticket regarding the cost of eager mode).
[Added info] Source: https://keras.io/api/optimizers/
... the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), ...
Optimizer.apply_gradients(
    grads_and_vars, name=None, skip_gradients_aggregation=False, **kwargs
)

skip_gradients_aggregation: If true, gradients aggregation will not be performed inside the optimizer. Usually this arg is set to True when you write custom code aggregating gradients outside the optimizer.

@innat the root cause of this error is the tf.cond in your train_step.
One option is to work around the tf.cond by using gradient masking to only apply gradients every nth batch.
Here's a modified version of your colab which uses this approach and seems to be working.
It's probably marginally less performant than if the graph could be fully compiled with the conditional in it, but merging a subgraph which has a conditional on a synchronized variable is (I think) a fundamental limitation of running TF in distributed mode.
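For readers who can't open the colab, here is a minimal single-device sketch of the masking idea (my own illustration, not @ianstenbit's exact code; distributed details such as variable aggregation settings are omitted, and the model/data are made up):

import numpy as np
import tensorflow as tf

class MaskedGAModel(tf.keras.Model):
    def __init__(self, *args, n_gradients=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int64)
        self.n_acum_step = tf.Variable(0, dtype=tf.int64, trainable=False)
        # Requires functional construction (inputs/outputs) so that
        # trainable_variables already exist at this point.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        x, y = data
        self.n_acum_step.assign_add(1)
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        for acc, g in zip(self.gradient_accumulation, grads):
            acc.assign_add(g / tf.cast(self.n_gradients, g.dtype))
        # 1.0 on every n-th step, 0.0 otherwise: the mask replaces tf.cond,
        # so apply_gradients (a synchronization point) runs unconditionally.
        apply_mask = tf.cast(
            tf.equal(self.n_acum_step % self.n_gradients, 0), tf.float32)
        self.optimizer.apply_gradients([
            (tf.cast(apply_mask, acc.dtype) * acc, var)
            for acc, var in zip(self.gradient_accumulation,
                                self.trainable_variables)])
        # Zero the accumulators after an apply step, again without tf.cond.
        for acc in self.gradient_accumulation:
            acc.assign(tf.cast(1.0 - apply_mask, acc.dtype) * acc)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = MaskedGAModel(inputs, outputs, n_gradients=4)
model.compile(optimizer="sgd", loss="mse")
model.fit(np.random.random((64, 32)), np.random.random((64, 1)), epochs=1)

Note that with stateful optimizers, the unconditional apply_gradients still mutates optimizer state on non-apply steps, which is exactly the numerical discrepancy discussed below.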
@ianstenbit thanks for the reply.
Your updated code does run on multi-GPU. However, I've noticed a noticeable performance drop compared to a single device with n_gradients=1.
From my gist,
Epoch 1/3
10000/10000 - 23s - loss: 0.2041 - accuracy: 0.9387
Epoch 2/3
10000/10000 - 23s - loss: 0.0937 - accuracy: 0.9708
Epoch 3/3
10000/10000 - 23s - loss: 0.0667 - accuracy: 0.9791
<keras.callbacks.History at 0x7f983006fe50>
with yours
Epoch 1/3
10000/10000 - 68s - loss: 0.6961 - accuracy: 0.8416
Epoch 2/3
10000/10000 - 22s - loss: 0.6387 - accuracy: 0.8541
Epoch 3/3
10000/10000 - 22s - loss: 0.6387 - accuracy: 0.8541
<keras.callbacks.History at 0x7f97d41fd1d0>
@innat looks like I had a silly mistake in the line of code where I was zeroing out gradients after applying them
I had
self.gradient_accumulation[i].assign(-1 * logical_grads[i])
but it should have been
self.gradient_accumulation[i].assign_add(-1 * logical_grads[i])
After making this change, I got results much closer to your original ones.
It occurred to me, though, that to avoid any rounding errors it's probably better to use
self.gradient_accumulation[i].assign(
    tf.cast(tf.logical_not(should_apply), self.gradient_accumulation[i].dtype)
    * self.gradient_accumulation[i]
)
It's still not precisely the same numerically as your original implementation. I think this may be because calling optimizer.apply_gradients even with all-zero gradients is likely adjusting the optimizer state.
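A tiny standalone illustration of that point (my own example, using plain Adam outside any strategy):

import tensorflow as tf

v = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam(learning_rate=0.1)
opt.apply_gradients([(tf.constant(0.5), v)])   # one real update step
before = v.numpy()
opt.apply_gradients([(tf.zeros_like(v), v)])   # all-zero gradient step
# The variable still moves: Adam's first moment decays toward zero rather
# than vanishing instantly, and the iteration counter advances, so the
# zero-gradient step changes both the optimizer state and the weights.
print(before, v.numpy())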
Thanks for the update. Could you please check with multiple epochs (i.e. 10)? I observe that the loss and accuracy don't change after 2 epochs. Tested with n_gradients=1.
Thanks for the update. Could you please check with multiple epochs (i.e. 10)? I observe that the loss and accuracy don't change after 2 epochs. Tested with n_gradients=1.
Yes I see this behavior, and I think it's probably due to calling optimizer.apply_gradients with zero gradients unnecessarily. I am tinkering to find a potential workaround.
I think in order to correctly perform gradient accumulation, you'd likely need to subclass Optimizer and encapsulate the logic in your optimizer.
This seems like a constraint of tf.distribute that we can't work around fully in the Keras train_step, so from the Keras POV I think there's nothing to be done.
@rchao to confirm
Thanks Ian. Yes, this appears not supported by tf.distribute at this time, and I would recommend filing an issue on tf.distribute if you would like such support.
This seems like a constraint of tf.distribute that we can't work around fully in the Keras train_step, so from the Keras POV I think there's nothing to be done.
Here, the aim is to make it possible to execute this within a custom fit (overriding train_step). I don't want to subclass the optimizer at the moment. The gist I shared works pretty well with a single-device strategy. The problem arises for multi-GPU cases, as some feature is not supported.
... and I would recommend filing an issue on tf.distribute if you would like such support.
@rchao could you please create an issue? Or, this technique should be supported: https://github.com/keras-team/tf-keras/issues/107 cc @chenmoneygithub
@4uiiurz1 I read on SO that you extended this technique for multi-GPU support. Could you please give some feedback regarding that? Thanks.
I don't want to subclass the optimizer at the moment.
Is there a specific reason why you don't want to wrap the optimizer?
The main reason why I never did that was that I failed to find a working implementation. I found quite a few attempts, some of which even run (to an extent), but when running a simple benchmark, training results were quite different from regular batch training.
Just now, I managed to get an optimizer wrapper working (see here). This was based on the work by @stefan-falk and @fsx950223. At least it yields extremely similar results to regular batch training.
If you wish to try it out, there is a test script here, in the GradientAccumulator repo.
I was unable to test multi-GPU support, as I do not have access to multiple GPUs until tomorrow. But I could update you on the matter, likely tomorrow. Note that right now, only SGD is supported. I will need to debug why dynamic optimizers such as Adam are not working as well as SGD; I'm not observing the same with the train_step overload approach.
Is there a specific reason why you don't want to wrap the optimizer?
I don't mind using that, but I strongly prefer to override the train step. Adding a new ticket: https://github.com/tensorflow/tensorflow/issues/59487
I don't mind using that, but I strongly prefer to override the train step.
No worries.
If anyone is interested in playing around with the optimizer wrapper solution, here is a gist demonstrating that it works with tf.distribute.MirroredStrategy.
I don't have access to multiple GPUs at the moment, but perhaps someone else does and is interested in trying.
@andreped
I quickly tested on Kaggle (2x T4 GPU) with TF 2.6.4 and got the following error.
FailedPreconditionError: 2 root error(s) found.
(0) Failed precondition: Could not find variable _AnonymousVar40. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar40/N10tensorflow3VarE does not exist.
[[{{node cond_1/then/_12/cond_1/GAOptimizerWrapper/GAOptimizerWrapper/update_3/update_0/StatefulPartitionedCall/cond/then/_306/cond/Cast/ReadVariableOp}}]]
[[Func/cond/then/_0/cond/cond/then/_134/cond/cond/cond/then/_330/input/_471/_140]]
(1) Failed precondition: Could not find variable _AnonymousVar40. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar40/N10tensorflow3VarE does not exist.
[[{{node cond_1/then/_12/cond_1/GAOptimizerWrapper/GAOptimizerWrapper/update_3/update_0/StatefulPartitionedCall/cond/then/_306/cond/Cast/ReadVariableOp}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_2656]
Function call stack:
train_function -> train_function
You will not face this error in colab (with tf 2.6.4).
I quickly tested on Kaggle (2x T4 GPU) with TF 2.6.4 and got the following error.
Oh, OK. Nice to know! Will have to do some further debugging. Cheers :] Anyways, the gist serves as a nice foundation for making a proper solution.
I was able to reproduce the bug in Kaggle, @innat. Love that you have access to two GPUs for free on Kaggle!
I've shared my Kaggle notebook here, if anyone wishes to debug this further. Any ideas would be much appreciated!
It seems to work just fine with one GPU, but fails during gradient update with multiple in MirroredStrategy.
Note that switching to tf 2.8.0 yields a different error, which might be easier for some of you to unravel:
Node: 'cond/ResourceApplyGradientDescent'
3 root error(s) found.
(0) INVALID_ARGUMENT: alpha is not a scalar: [0]
[[{{node cond/ResourceApplyGradientDescent}}]]
(1) INVALID_ARGUMENT: alpha is not a scalar: [0]
[[{{node cond/ResourceApplyGradientDescent}}]]
[[div_no_nan_1/CollectiveReduceV2_3/_137]]
(2) INVALID_ARGUMENT: alpha is not a scalar: [0]
[[{{node cond/ResourceApplyGradientDescent}}]]
@ianstenbit Could you please provide some details about steps_per_execution? From the docs, it says:
steps_per_execution: Int. Defaults to 1. The number of batches to run during each tf.function call. Running multiple batches inside a single tf.function call can greatly improve performance on TPUs or small models with a large Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, Callback.on_batch_begin and Callback.on_batch_end methods will only be called every N batches (i.e. before/after each tf.function execution).
Is it a possible alternative to gradient accumulation techniques? What does it mean when it says "number of batches to run during each tf.function call"? Is the corresponding gradient accumulated for each batch?
import numpy as np
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        print()
        print(x.shape, y.shape, tf.shape(x)[0].numpy())
        print()
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

x = np.random.random((100, 32))
y = np.random.random((100, 1))

with strategy.scope():
    # Construct and compile an instance of CustomModel
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = CustomModel(inputs, outputs)
    model.compile(
        optimizer="adam",
        loss="mse",
        metrics=["mae"],
        steps_per_execution=1,
        run_eagerly=True,
    )

model.fit(
    x, y,
    validation_data=(x, y),
    epochs=1,
    batch_size=32,
)
With gpu=2 and steps_per_execution=1, it gives the following (1st step):
(16, 32) (16, 1) 16
(16, 32) (16, 1) 16
1/4 [======>.......................] - ETA: 0s - loss: 1.3174 - mae: 1.0198
And with steps_per_execution=2, it gives the following (1st and 2nd steps at a time):
(16, 32) (16, 1) 16
(16, 32) (16, 1) 16
(16, 32) (16, 1) 16
(16, 32) (16, 1) 16
2/4 [==============>...............] - ETA: 0s - loss: 0.2012 - mae: 0.3581
It looks like a possible alternative to the gradient accumulation technique. I'd like to know what happens when steps_per_execution=N for M GPUs inside the train_step function, and how the losses are calculated for each call when N > 1.
Also, does steps_per_execution apply to validation_data? Why is it not available in model.evaluate or model.predict?
Hi @innat

steps_per_execution does apply to model.evaluate and model.predict when specified in model.compile.

steps_per_execution does not cause gradient accumulation. Weights are updated once per batch, even when two steps are executed inside the same tf.function.

If steps_per_execution=N and you have M GPUs, every time the host device sends a unit of work to any GPU, it will send N batches.
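To make the distinction concrete, here is a small self-contained sketch (my own, with an illustrative model and random data):

import numpy as np
import tensorflow as tf

x = np.random.random((128, 32)).astype("float32")
y = np.random.random((128, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(32,))])
# Four batches are dispatched per tf.function call, but apply_gradients
# still runs once per batch: four weight updates per execution. This is
# a dispatch-performance knob, not gradient accumulation.
model.compile(optimizer="adam", loss="mse", steps_per_execution=4)
model.fit(x, y, batch_size=32, epochs=1)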
I think the main reason for the problem is that TensorFlow does not allow control flow containing any synchronization op in a replica context wrapped by tf.function. I guess tf.function builds a graph for each branch of the control flow, so each replica may enter a different branch, which would cause conflicts in synchronization.
The key is "replica context", so switching to executing the tf.cond in the cross-replica context will directly solve the problem without any optimizer-specific wrapper or modification of the optimizer's internal implementation.
Here is an example:
def apply_accumulated_gradients(grads_and_vars):
    # actually apply gradients logic
    pass

should_apply = ...  # a boolean flag

def apply_gradients_cross_replica(strategy, grads_and_vars):
    def _apply_fn():
        strategy.extended.call_for_each_replica(
            apply_accumulated_gradients, args=(grads_and_vars,))
    tf.cond(should_apply, _apply_fn, lambda: None)

# execute control flow with a synchronization op in the cross-replica context
tf.distribute.get_replica_context().merge_call(
    apply_gradients_cross_replica, args=(grads_and_vars,))
@AIGideon Have you tested it? Also, gradient accumulation is now supported (didn't test though). https://github.com/keras-team/keras/pull/18951
@innat Yes, I tested that it works perfectly with tf.distribute.MirroredStrategy (TensorFlow versions 2.2~2.12). Other distribute strategies remain to be tested.
I don't know whether the Keras 3 implementation solves this problem, but switching to a cross-replica context from a replica context is a very common usage in TF. I just wonder why the Keras 2 (tf-keras) community has been troubled by the implementation of gradient accumulation for such a long time and no solid solution has ever been given. I've seen other implementations from the community, and most of them are based on the following three approaches to avoid control flow:

1. Back up all model variables and optimizer states before calling inner_optimizer.apply_gradients(), and reset them back to the original state after updating. This can indeed achieve correct results in theory, but the backup and recovery process requires several copies of all model variables and optimizer states. The increase in memory usage may be more than the memory saved by gradient accumulation itself, which defeats the purpose of gradient accumulation to save memory.

The other two approaches modify the apply_gradients() code or logic so that the variable update operation (like var.assign()) does not actually modify its value during the accumulation phase. There are two ways to achieve this:

2. Directly modify apply_gradients() of each optimizer subclass or the public method of the optimizer base class. This is too complex and difficult to keep working across Keras (tf-keras) versions.

3. Implement a DummyUpdateVariable as a replacement for tf.Variable (similar to the AutoCastVariable used in Keras mixed precision: wrap tf.Variable and override some methods), rewriting its update methods like assign() and assign_add() so that they can perform the variable update op without modifying its value. Then wrap all optimizer state variables and model trainable variables passed to the inner optimizer in DummyUpdateVariable. This approach avoids modifying optimizer code and is therefore decoupled from the Keras (tf-keras) version.

Back to the topic: I think the best way to implement gradient accumulation in Keras 2 (tf-keras) is to organize my example code above into a generic OptimizerWrapper that can receive any tf.keras.optimizers.Optimizer instance and does not require any optimizer-specific logic/code modifications.
Could you please share a complete gist with your approach?
@innat OK, I will give an example based on tensorflow==2.12.0 (which takes the new Keras optimizer API under keras/optimizers/optimizer_experimental/ as the default optimizer instead of optimizer_v2):
import tensorflow as tf
from typing import Iterable, List, Tuple


class GradientAccumulationOptimizer(tf.keras.optimizers.Optimizer):
    def __init__(
        self,
        optimizer: tf.keras.optimizers.Optimizer,
        gradient_accumulation_steps: int = 1,
        name: str = 'GradientAccumulationOptimizer',
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.optimizer = optimizer
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def apply_gradients(
        self,
        grads_and_vars: Iterable[Tuple[tf.Tensor, tf.Variable]],
        *args,
        **kwargs
    ):
        grads_and_vars = list(grads_and_vars)
        vars = [var for _, var in grads_and_vars]
        if not hasattr(self, '_built') or not self._built:
            self.build(vars)

        self.step.assign_add(1)
        should_apply = tf.equal(self.step % self.gradient_accumulation_steps, 0)

        # update accumulated gradients
        self._update_accumulated_grads(grads_and_vars)

        # apply gradients
        def _cross_replica_apply_gradients(strategy, grads_and_vars):
            def _apply_fn():
                strategy.extended.call_for_each_replica(
                    self._apply_accumulated_grads,
                    args=(grads_and_vars, *args), kwargs=kwargs)
            tf.cond(should_apply, _apply_fn, lambda: None)

        tf.distribute.get_replica_context().merge_call(
            _cross_replica_apply_gradients, args=(grads_and_vars,))

        # reset accumulated gradients if necessary
        tf.cond(should_apply, self._reset_accumulated_grads, lambda: None)
        return self.optimizer.iterations

    def _update_accumulated_grads(
        self,
        grads_and_vars: List[Tuple[tf.Tensor, tf.Variable]]
    ):
        for i, (grad, _) in enumerate(grads_and_vars):
            self.accumulated_grads[i].assign_add(grad)

    def _apply_accumulated_grads(
        self,
        grads_and_vars: List[Tuple[tf.Tensor, tf.Variable]],
        *args,
        **kwargs
    ):
        accumulated_grads_and_vars = [
            (
                self.accumulated_grads[i] / tf.cast(
                    self.gradient_accumulation_steps,
                    self.accumulated_grads[i].dtype),
                var
            )
            for i, (_, var) in enumerate(grads_and_vars)
        ]
        self.optimizer.apply_gradients(
            accumulated_grads_and_vars, *args, **kwargs)

    def _reset_accumulated_grads(self):
        for grad in self.accumulated_grads:
            grad.assign(tf.zeros_like(grad))

    def build(self, var_list: List[tf.Variable]):
        super().build(var_list)
        self.optimizer.build(var_list)
        self.accumulated_grads = [
            tf.Variable(
                initial_value=tf.zeros_like(var),
                trainable=False,
                aggregation=tf.VariableAggregation.NONE)
            for var in var_list
        ]
        self.step = tf.Variable(
            initial_value=0, trainable=False, dtype=tf.int64,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        self._built = True
You can use it to wrap any optimizer like SGD or Adam, and this wrapper itself can also be wrapped by LossScaleOptimizer (this usually happens automatically in model.compile() when mixed precision is enabled).
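For anyone wanting to try it, a minimal usage sketch (my own; it assumes the GradientAccumulationOptimizer class above is in scope, and the model and data are illustrative):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, input_shape=(32,)),
    ])
    # Wrap a regular optimizer; gradients are applied every 4 steps,
    # giving an effective batch size of 4 * per-step batch size.
    optimizer = GradientAccumulationOptimizer(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        gradient_accumulation_steps=4)
    model.compile(optimizer=optimizer, loss="mse")

x = np.random.random((256, 32)).astype("float32")
y = np.random.random((256, 1)).astype("float32")
model.fit(x, y, batch_size=32, epochs=1)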
I haven't tried later TensorFlow versions, but if you use an earlier version, some modifications may be needed:

1. The old optimizer_v2 API has no self.build(), but self._create_all_weights() instead.
2. The apply_gradients() method needs to return an op instead of a tensor in graph mode to be compatible with the LossScaleOptimizer logic at that time. You may modify the return line like this:

if tf.executing_eagerly():
    return self.optimizer.iterations
else:
    return self.optimizer.iterations.assign_add(0, read_value=False)
@AIGideon Thanks. I was trying to achieve this with a custom fit method.
Thanks to @AIGideon, @innat and @andreped
I could implement a GAOptimizer by modifying @AIGideon's code and referring to @andreped's implementation.
This GAOptimizer
System information.
Describe the problem
I have code that works fine but gives the following error if I use with strategy.scope().

Describe the expected behavior

I think it should work.
Standalone code to reproduce the issue
The code is for gradient accumulation techniques. Here it is done by overriding train_step with the fit method. This code works fine (as said above) without with strategy.scope(). Now, I'd like to use it for multi-GPU cases, so I use the strategy scope but ended up with the above-mentioned error. Gist.
Follow-up Questions
1. Do I need to adjust BATCH_SIZE = 32 * strategy.num_replicas_in_sync inside the train_step method, or will it be handled automatically?
2. For mixed precision, it's suggested to wrap the optimizer with LossScaleOptimizer and use optimizer.get_scaled_loss(loss) and optimizer.get_unscaled_gradients(gradients). But the official documentation covers the normal fit and custom-loop training cases. In the case of a custom loop, it's suggested to wrap the optimizer and scale the loss and gradients, but what about the combination of fit and a custom loop (overriding train_step)? Does it still need the optimizer wrapping and loss/gradient scaling, or will that be handled by the API?

Others: https://github.com/keras-team/tf-keras/issues/107 cc @chenmoneygithub @nikitamaia @bhack