Open edwardyehuang opened 4 months ago
The current situation is:
1) tf.distribute.ReplicaContext.all_reduce fails directly in graph mode with Keras 3 + TensorFlow 2.15.
2) Any other case (e.g. eager with Keras 3.0, or graph mode with Keras 2.15) in the same TensorFlow environment works fine.
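(For clarity on the two cases above: the failing "graph mode" is the default compiled train step, while the working eager path can be forced at compile time. A minimal sketch with a placeholder model, not the repro further below:)

import keras

# Placeholder model just to illustrate the run_eagerly switch.
model = keras.Sequential([keras.layers.Dense(1)])

# Eager train step -- the case reported as working.
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)

# Default graph-mode train step (run_eagerly=False) -- the case that fails
# together with ReplicaContext.all_reduce in Keras 3.
# model.compile(optimizer="sgd", loss="mse", run_eagerly=False)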
Hi @edwardyehuang ,
If you want to use Keras with the TF 2.x backend, you need to install the tf-keras package using pip install tf-keras and set the environment variable TF_USE_LEGACY_KERAS=1.
Also, could you please confirm the behaviour with TF 2.16, which is compatible with Keras 3?
Thanks!
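(A minimal sketch of that setup, assuming tf-keras is already installed; the variable must be set before TensorFlow is imported:)

import os

# Must be set before importing TensorFlow so that tf.keras resolves to the
# legacy Keras 2 (tf-keras) implementation instead of Keras 3.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf

print(tf.keras.__version__)  # expected to report a 2.x version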
Thanks for the reply. I am reporting a bug in Keras 3; my code works fine in Keras 2. The issue also happens with TensorFlow 2.16.
Update: after removing replica_ctx.all_reduce, the error is gone. So the question now is why replica_ctx.all_reduce works fine in Keras 2.x but fails in Keras 3.x. Still investigating.
@SuryanarayanaY Note that this is a bug report, not a support request.
Hi @edwardyehuang , if possible could we have a reproducible code snippet for this?
If this happens with TF 2.16 and Keras 3, then it may need investigation. If you are working on this, I will leave it as it is for now.
@fchollet @SuryanarayanaY
A reproducible code snippet is presented below. Make sure you test it on at least 2 GPUs (and set batch_size >= num_gpus).
Remove either the conv or the all_reduce, and the problem disappears.
The code below works fine with Keras 2.15; it only fails in Keras 3.
It looks like this is caused by the combination of all_reduce and the placeholder tensors used to build the model in keras.Layer.__call__.
However, given my limited knowledge and time, I'm unable to provide a quick fix, so I need help.
import keras
import tensorflow as tf

BATCH_SIZE = 4

tf.get_logger().setLevel('INFO')

strategy = tf.distribute.MirroredStrategy()

# Make model ##########################################################################

with strategy.scope():

    class SimpleModel(keras.Model):

        def __init__(self, name=None):
            super().__init__(name=name)

        def build(self, input_shape):
            self.l = keras.layers.Conv2D(3, (1, 1), padding='same')
            super().build(input_shape)

        def call(self, inputs, training=False):
            x = inputs

            if training:
                # Cross-replica reduction: works in Keras 2.15,
                # fails in Keras 3 graph mode.
                x = tf.distribute.get_replica_context().all_reduce(
                    tf.distribute.ReduceOp.SUM, x)

            x = self.l(x)

            return x

    m = SimpleModel()

    m.compile(
        optimizer=keras.optimizers.SGD(learning_rate=1e-3),
        loss=keras.losses.MeanSquaredError(),
    )

# Make dataset ##########################################################################

def simple_data_generator(num_samples=-1, size=(17, 17)):
    counter = 0

    while True:
        random_data = tf.random.uniform(
            shape=size,
            minval=-2,
            maxval=2,
            dtype=tf.float32
        )

        yield random_data, random_data

        counter += 1

        if num_samples > 0 and counter >= num_samples:
            break

io_shape = (17, 17, 3)

train_dataset = tf.data.Dataset.from_generator(
    simple_data_generator,
    args=(-1, io_shape),
    output_signature=(
        tf.TensorSpec(shape=io_shape, dtype=tf.float32),
        tf.TensorSpec(shape=io_shape, dtype=tf.float32),
    )
)

train_dataset = train_dataset.batch(BATCH_SIZE)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
train_dataset = train_dataset.with_options(options)

train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Train #################################################################################

m.fit(train_dataset, epochs=1, verbose=1, steps_per_epoch=1000)
Note that I found Keras 3 results in 2 more placeholders in graph.capture than Keras 2.15; still under investigation.
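(For reference, the captures of a traced tf.function can be inspected roughly like this; a simplified, hypothetical sketch, not the actual Keras training graph:)

import tensorflow as tf

bias = tf.constant([1.0])  # eager tensor referenced from inside the function

@tf.function
def step(x):
    # `bias` is not an argument, so the traced graph captures it and creates
    # an internal placeholder for it.
    return x + bias

cf = step.get_concrete_function(tf.TensorSpec(shape=(None, 3), dtype=tf.float32))

# Each capture pairs the external tensor with the placeholder created for it
# inside the traced FuncGraph.
for external, internal in cf.graph.captures:
    print(external, "->", internal)
print("num captures:", len(cf.graph.captures))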
@qlzh727
I hit the same bug while implementing a gradient accumulation feature. Can somebody help me, please?
@SuryanarayanaY I just noticed that Kaggle provides 2×T4 GPUs, so here is the code on Kaggle:
@SuryanarayanaY
Hi @edwardyehuang ,
Thanks for the reminder. I have replicated the issue in a multi-GPU VM environment and attached the logs below.
Any update on this? It is the only barrier keeping me from moving to Keras 3, and I can contribute a lot of other features after that.
@sachinprasadhs @fchollet
As an additional update, I believe that this bug is triggered whenever the Adam optimizer is applied in a distributed context (I haven't done an exhaustive search over the optimizers; that's just the one I noticed it on). This bug is currently being masked because, due to the bug described in #19891, the cross-replica reduction of the gradients isn't actually being applied. But if you fix that issue (so that ReplicaContext.all_reduce is applied to the gradients), then you run into this error whenever you try to use the Adam optimizer.
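(For context, this is the kind of cross-replica gradient reduction being discussed. A hypothetical sketch of an explicit all_reduce over gradients in a custom train_step under MirroredStrategy; it illustrates the pattern, not the actual Keras internals, and the class name is made up:)

import tensorflow as tf
import keras

class ReducedGradModel(keras.Model):
    # Hypothetical model: reduce the gradients across replicas explicitly
    # before the optimizer update.
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)

        # Sum the per-replica gradients; this is the ReplicaContext.all_reduce
        # call that runs into the error described in this issue.
        ctx = tf.distribute.get_replica_context()
        grads = ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)

        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

(With MirroredStrategy, Model.fit runs this train_step once per replica, which is where the failing all_reduce occurs.)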
@fchollet @sachinprasadhs @SuryanarayanaY
A humble suggestion: I believe this issue should be the top priority for the Keras team to solve. The existence of this issue makes it impossible for Keras 3 to perform correct distributed training.
@jeffcarp and @kiranbir for visibility.
Here's a smaller repro that doesn't require GPU: https://colab.research.google.com/drive/1uM8rhXuOW9nD3tGvD0gfCf2gY-jzijqg?usp=sharing
I looked into this for a bit. My hunch is there could be a subtle difference between the way the training-step tf.function is handled in the training loops in TF-Keras vs. Keras 3:
@grasskin also mentioned conditionals might not work in replica contexts that call merge_call:
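(The snippet @grasskin pointed to is omitted above. Purely as an illustration of that limitation, a hypothetical sketch of a conditional wrapped around a cross-replica reduction, the pattern that is fragile because the reduction may synchronize replicas via merge_call:)

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def replica_fn(x, training):
    ctx = tf.distribute.get_replica_context()

    # If `training` were a symbolic tensor, autograph would turn this `if`
    # into a tf.cond, putting the all_reduce (which may merge_call
    # internally) inside graph control flow -- the situation suspected to
    # break here.
    if training:
        x = ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)
    return x

# With a plain Python bool the conditional is resolved at trace time and the
# reduction stays outside control flow, which works.
out = strategy.run(replica_fn, args=(tf.constant([1.0]), True))
print(out)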
@jeffcarp
Any updates? I just ran the Colab you provided with the latest TF + Keras nightly versions, and a new error appeared instead of the old one:
InvalidArgumentError: {{function_node __wrapped__AddN_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Inputs to operation AddN of type AddN must have the same size and shape. Input 0: [1,17,17,3] != input 1: [0,17,17,3] [Op:AddN] name:
Well, currently, I have no idea how to debug this issue because the error gives no useful information.