keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Training using a `tf.data.Dataset` and `steps_per_execution` > 32 fails #20344

Closed (nicolaspi closed this issue 3 weeks ago)

nicolaspi commented 1 month ago

Training using a tf.data.Dataset and steps_per_execution > 32 fails with:

ValueError: An unusually high number of `tf.data.Iterator.get_next()` calls was detected. This suggests that the `for elem in dataset: ...` idiom is used within tf.function with AutoGraph disabled. This idiom is only supported when AutoGraph is enabled.

Reproduction code:

import keras
import tensorflow as tf

x = tf.random.normal((1000, 10))
y = tf.random.uniform((1000,), maxval=2, dtype=tf.int32)

# Create a tf.data.Dataset
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(1000).batch(32)

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    keras.layers.Dense(1, activation='sigmoid')
])

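# Pass steps_per_execution in the same compile() call: a second compile()
# would replace the optimizer, loss, and metrics configured by the first.
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=33)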

model.fit(dataset, epochs=5)

mehtamansi29 commented 1 month ago

Hi @nicolaspi -

Thanks for reporting this issue. The error occurs because the loop runs inside a tf.function with AutoGraph disabled, and in Keras 3 AutoGraph is disabled by default. More details about AutoGraph can be found here.

You need to use eager execution mode, model.compile(steps_per_execution=33, run_eagerly=True), to avoid this error in Keras 3.

A gist is attached for reference.

nicolaspi commented 1 month ago

Hi @mehtamansi29, thanks for the answer. The issue arises from a protection heuristic defined in tf.data here. The protection is disabled in eager mode, but running eagerly is not a viable solution because of its performance impact. My solution was to override make_train_function and replace:

@tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
    for _ in range(self.steps_per_execution):
        outputs = one_step_on_iterator(iterator)
    return outputs

with

# @tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
    for _ in tf.range(self.steps_per_execution):
        outputs = one_step_on_iterator(iterator)
    return outputs

(Note the change from range to tf.range and the commented-out do_not_convert decorator: with AutoGraph active, the tf.range loop is converted into a while_loop instead of being unrolled.)
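To make the difference concrete, here is a toy model in plain Python of why unrolling trips the guard while a graph-level loop does not. It does not use TensorFlow and is not the tf.data implementation. `TracedIterator`, `trace_unrolled`, `trace_as_while_loop`, and `calls_recorded` are hypothetical names, and the limit of 32 is an assumption taken from the fact that this issue only appears for steps_per_execution > 32.

```python
class TracedIterator:
    """Toy stand-in for a dataset iterator observed while a graph is traced.

    Hypothetical class, not part of TensorFlow or Keras. Each get_next()
    call stands for one op recorded in the traced graph.
    """

    def __init__(self, limit=32):
        # Assumption: 32 is the largest steps_per_execution that works
        # according to this issue.
        self.limit = limit
        self.calls = 0

    def get_next(self):
        self.calls += 1
        if self.calls > self.limit:
            raise ValueError(
                "An unusually high number of get_next() calls was detected."
            )
        return self.calls


def trace_unrolled(iterator, steps):
    """Python `range` loop: the body runs `steps` times during tracing."""
    outputs = None
    for _ in range(steps):
        outputs = iterator.get_next()
    return outputs


def trace_as_while_loop(iterator, steps):
    """Graph-level loop: the body is traced once, whatever `steps` is."""
    # In a real graph, `steps` would only become the bound of the loop op.
    return iterator.get_next()


def calls_recorded(trace_fn, steps):
    """Return the number of recorded get_next() calls, or None if the guard trips."""
    iterator = TracedIterator()
    try:
        trace_fn(iterator, steps)
    except ValueError:
        return None
    return iterator.calls


if __name__ == "__main__":
    print(calls_recorded(trace_unrolled, 32))       # 32
    print(calls_recorded(trace_unrolled, 33))       # None (guard trips)
    print(calls_recorded(trace_as_while_loop, 33))  # 1
```

In this model the unrolled loop records one call per step, so 33 steps exceed the limit, while the while-loop variant records a single call regardless of the step count.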

shkarupa-alex commented 3 weeks ago

Got the same issue. Manually limiting steps_per_execution with min(32, wanted_steps_per_execution) works well as a temporary workaround.
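A minimal sketch of that workaround as a helper, assuming 32 is the largest value that avoids the error (as reported in this thread). `capped_steps_per_execution` and `MAX_STEPS_PER_EXECUTION` are hypothetical names, not part of Keras.

```python
# Assumption: 32 is the largest steps_per_execution that avoids the
# tf.data error, as reported in this issue.
MAX_STEPS_PER_EXECUTION = 32


def capped_steps_per_execution(wanted_steps_per_execution,
                               limit=MAX_STEPS_PER_EXECUTION):
    """Clamp the requested steps_per_execution to the known-safe limit."""
    if wanted_steps_per_execution < 1:
        raise ValueError("steps_per_execution must be at least 1")
    return min(limit, wanted_steps_per_execution)
```

The result can be passed straight to compile, for example model.compile(optimizer='adam', loss='binary_crossentropy', steps_per_execution=capped_steps_per_execution(33)). Values at or below the limit pass through unchanged.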

google-ml-butler[bot] commented 3 weeks ago

Are you satisfied with the resolution of your issue?
