keras-team / keras


Drop in timing performance from Keras 2 to Keras 3 with TensorFlow as backend #19953

Closed mbarbetti closed 1 month ago

mbarbetti commented 3 months ago

After upgrading a Python package for High Energy Physics flash-simulation (mbarbetti/pidgan) to be compatible with Keras 3, I noticed a significant drop in timing performance when moving from Keras 2 to Keras 3, as also reported in https://github.com/mbarbetti/pidgan/issues/10.

After some iterations, I have been able to reproduce the problem, which seems strictly related to Keras 3 (with the TensorFlow backend).

The code I used as a reference follows:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np

from time import time

# Toy binary-classification dataset
chunk_size = 250_000
x = np.random.normal(size=(chunk_size, 4))
y = np.random.choice([0.0, 1.0], size=chunk_size, p=[0.5, 0.5])

# 10 blocks of Dense + Dropout, followed by a sigmoid output
model = keras.Sequential()
for _ in range(10):
    model.add(keras.layers.Dense(128, activation="relu"))
    model.add(keras.layers.Dropout(rate=0.1))
model.add(keras.layers.Dense(1, activation="sigmoid"))

model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=[keras.metrics.AUC(name="auc")],
    jit_compile=False,
)

start = time()
model.fit(x=x, y=y, batch_size=500, epochs=20, validation_split=0.2)
stop = time()

print(f"{stop - start} s")

The above Python script has been executed within two different conda environments: one based on TensorFlow 2.14.1 (with Keras 2.14.0), the other on TensorFlow 2.16.1 (with Keras 3.3.3). This exercise has been repeated on three different devices (CPU-only plus two different GPU cards), and in all cases I have observed a drop in timing performance when passing from Keras 2 to Keras 3.

The device details and the time measured are reported in the following table:

| CPU model | # cores | RAM | GPU model | GPU partition | time on TF2.14 | time on TF2.16 |
|---|---|---|---|---|---|---|
| AMD EPYC 7282 | 8 | 8 GB | - | - | 73.1612 s | 79.3706 s |
| Intel Xeon Gold 5218 | 8 | 8 GB | Quadro RTX 5000 | 1/1 | 51.1899 s | 65.2110 s |
| AMD EPYC 7513 | 8 | 8 GB | NVIDIA A100 80GB | 1/7 | 46.0913 s | 63.2520 s |
mbarbetti commented 3 months ago

Let me clarify that all the devices are provisioned as virtual machines with OpenStack as the hypervisor.

fchollet commented 3 months ago

Note that any benchmark should only measure time starting after the first training step, otherwise you are including the initial startup overhead -- what you need to measure is the step time.

If you care about performance you should:

  • Use jit_compile=True
  • Tune your batch size
  • Tune the value for steps_per_execution
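
For the model above, a minimal sketch of what those options look like (the values below are illustrative placeholders, not tuned recommendations):

# Illustrative only: enable XLA compilation of the train step and run
# several batches per outer call; both values need to be tuned per workload.
model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=[keras.metrics.AUC(name="auc")],
    jit_compile=True,
    steps_per_execution=50,
)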

mbarbetti commented 3 months ago

Hi @fchollet,

You are right: my naive assumption was that, with enough epochs, such overhead could be neglected, but that is probably not the case with only 20 epochs. To address the problem I have prepared this custom callback to measure the average step time in each epoch:

import keras
import numpy as np
from time import time

class BatchTime(keras.callbacks.Callback):
    """Records per-batch training times and the mean step time of each epoch."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self.start = time()

    def on_train_batch_end(self, batch, logs=None):
        self.epoch_times.append(time() - self.start)

    def on_epoch_end(self, epoch, logs=None):
        try:
            # keras.ops is only available in Keras 3
            step_time = keras.ops.mean(self.epoch_times)
        except AttributeError:
            # fall back to NumPy on Keras 2
            step_time = np.mean(self.epoch_times)
        self.times.append(step_time)

This callback is passed to the fit() method, and the average step time over the whole training is computed after excluding the first epoch to remove the aforementioned overhead:

# [...]

batch_time = BatchTime()

start = time()
model.fit(x=x, y=y, batch_size=500, epochs=20, validation_split=0.2, callbacks=[batch_time])
stop = time()

step_time = 1e3 * np.mean(batch_time.times[1:])  # in ms

print(f"Average step time: {step_time:.4f} ms")
print(f"Total training time: {stop - start:.4f} s")

This "new configuration" has been used to repeat the exercise originally described and the results are reported as a reference in the following table:

| CPU model | GPU model | average step time* on TF2.14 | total time on TF2.14 | average step time* on TF2.16 | total time on TF2.16 |
|---|---|---|---|---|---|
| AMD EPYC 7282 | - | 8.0969 ms | 75.5528 s | 8.4981 ms | 81.4652 s |
| Intel Xeon Gold 5218 | Quadro RTX 5000 | 4.8602 ms | 52.8518 s | 6.3447 ms | 64.3022 s |
| AMD EPYC 7513 | NVIDIA A100 80GB | 4.0156 ms | 43.7599 s | 5.7462 ms | 56.0178 s |

*the average step time is measured excluding the first epoch for a more robust performance evaluation

mbarbetti commented 3 months ago

From @fchollet:

If you care about performance you should:

  • Use jit_compile=True
  • Tune your batch size
  • Tune the value for steps_per_execution

While I agree on the role of the batch size and the number of steps per epoch in performance optimization, this "study" aimed to compare Keras 2 with Keras 3 in terms of timing with common (and "randomly chosen") hyperparameters (i.e., batch size, number of epochs, dataset size).

For completeness, I have also repeated the usual exercise using jit_compile=True. In this configuration, Keras 3 manages to outperform Keras 2 in terms of timing on each of the tested devices. Surprisingly (or perhaps not), this result emerges only when looking at the step time, since the total time points in the opposite direction (presumably because the compilation overhead is concentrated in the first epoch). As usual, all the details are reported in the following table:

| CPU model | GPU model | average step time* on TF2.14 | total time on TF2.14 | average step time* on TF2.16 | total time on TF2.16 |
|---|---|---|---|---|---|
| AMD EPYC 7282 | - | 12.3829 ms | 112.9680 s | 11.9046 ms | 113.3379 s |
| Intel Xeon Gold 5218 | Quadro RTX 5000 | 1.9773 ms | 26.9512 s | 1.6777 ms | 32.8582 s |
| AMD EPYC 7513 | NVIDIA A100 80GB | 1.9033 ms | 25.7018 s | 1.5725 ms | 32.1113 s |

*the average step time is measured excluding the first epoch for a more robust performance evaluation

fchollet commented 2 months ago

One more thing: you should pick the right backend.

If you care a lot about overhead timing, then you can use the torch backend, which has minimal overhead (but typically executes slower). For large workloads, JAX is often the fastest backend.
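
For reference, switching backend only requires changing the environment variable set at the top of the reproducer, before keras is first imported; a minimal sketch:

import os

# Must be set before the first `import keras`; "tensorflow", "jax" and "torch"
# are the available backends (the chosen framework must be installed).
os.environ["KERAS_BACKEND"] = "jax"

import keras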

github-actions[bot] commented 2 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 1 month ago

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

google-ml-butler[bot] commented 1 month ago

Are you satisfied with the resolution of your issue?