keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Fix profiling for Tensorflow and JAX #20450

Closed nicolaspi closed 3 weeks ago

nicolaspi commented 3 weeks ago

Fix profiling for TensorFlow and JAX. Profiling with the JAX backend on Python versions below 3.12 raises segmentation faults on my local setup, so profiling is disabled for JAX when that version requirement is not met. Also note that profiling the first batch with `jit_compile=True` causes TensorBoard to hang indefinitely. Example profiling traces, produced with the code snippet that follows:

TensorFlow: tf_profile

JAX: jax_profile


```python
import os

# Select the backend before importing keras ("tensorflow" here).
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf  # For tf.data and tf.image
import tensorflow_datasets as tfds

from keras import layers
from keras.applications import ResNet50 as app_model
from keras.src import callbacks

IMG_SIZE = 224
BATCH_SIZE = 64
steps_per_execution = 1
dataset_size = BATCH_SIZE * max(steps_per_execution, 32)

dataset_name = "stanford_dogs"
(ds_train, ds_test), ds_info = tfds.load(
    dataset_name, split=["train", "test"], with_info=True, as_supervised=True
)
NUM_CLASSES = ds_info.features["label"].num_classes

# Truncate each split to a multiple of dataset_size and resize the images.
size = (IMG_SIZE, IMG_SIZE)
ds_train = ds_train.take(dataset_size * (len(ds_train) // dataset_size)).map(
    lambda image, label: (tf.image.resize(image, size), label)
)
ds_test = ds_test.take(dataset_size * (len(ds_test) // dataset_size)).map(
    lambda image, label: (tf.image.resize(image, size), label)
)

# One-hot / categorical encoding
def input_preprocess(image, label):
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

ds_train = ds_train.map(
    input_preprocess, num_parallel_calls=tf.data.AUTOTUNE
)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(
    input_preprocess, num_parallel_calls=tf.data.AUTOTUNE
)
ds_test = ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True).prefetch(
    tf.data.AUTOTUNE
)

model = app_model(
    include_top=True,
    weights=None,
    classes=NUM_CLASSES,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    steps_per_execution=steps_per_execution,
    jit_compile=True,
)

# model.summary()
logdir = "logs/"
# Profile batches 3 to 4 rather than the first batch, which hangs
# TensorBoard when jit_compile=True.
tb_cbk = callbacks.TensorBoard(
    logdir, histogram_freq=1, profile_batch="3,4", write_graph=False
)

epochs = 2
hist = model.fit(
    ds_train,
    epochs=epochs,
    validation_data=ds_test,
    callbacks=[tb_cbk],
)
```

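
The Python version requirement mentioned in the description could be guarded roughly as follows. This is only a minimal sketch, not the code from this PR: the names `MIN_PYTHON_FOR_JAX_PROFILING`, `jax_profiling_supported`, and `resolve_profile_batch` are hypothetical, and it assumes that passing `profile_batch=0` to the `TensorBoard` callback leaves profiling disabled.

```python
import sys

# Hypothetical constant: the description only says that profiling with the
# JAX backend segfaults on Python versions below 3.12.
MIN_PYTHON_FOR_JAX_PROFILING = (3, 12)


def jax_profiling_supported(version_info=None):
    """Return True when the interpreter meets the assumed minimum version."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= MIN_PYTHON_FOR_JAX_PROFILING


def resolve_profile_batch(backend, profile_batch, version_info=None):
    """Return profile_batch unchanged, or 0 for JAX on an older Python.

    A value of 0 is assumed to mean "profiling disabled".
    """
    if backend == "jax" and not jax_profiling_supported(version_info):
        return 0
    return profile_batch
```

With such a helper, the callback in the snippet above would receive `profile_batch=resolve_profile_batch("jax", "3,4")`, so the requested batches are profiled only when the interpreter is recent enough.
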
codecov-commenter commented 3 weeks ago

Codecov Report

Attention: Patch coverage is 65.78947% with 13 lines in your changes missing coverage. Please review.

Project coverage is 81.99%. Comparing base (272bb90) to head (6eebff5).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/backend/jax/tensorboard.py | 38.46% | 8 Missing :warning: |
| keras/src/callbacks/tensorboard.py | 70.58% | 1 Missing and 4 partials :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #20450      +/-   ##
==========================================
- Coverage   82.01%   81.99%   -0.02%
==========================================
  Files         514      515       +1
  Lines       47239    47271      +32
  Branches     7413     7421       +8
==========================================
+ Hits        38741    38762      +21
- Misses       6704     6712       +8
- Partials     1794     1797       +3
```

| [Flag](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `81.85% <63.15%> (-0.02%)` | :arrow_down: |
| [keras-jax](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `64.89% <34.21%> (-0.02%)` | :arrow_down: |
| [keras-numpy](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `59.83% <23.68%> (-0.03%)` | :arrow_down: |
| [keras-tensorflow](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `65.90% <57.89%> (-0.01%)` | :arrow_down: |
| [keras-torch](https://app.codecov.io/gh/keras-team/keras/pull/20450/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `64.83% <26.31%> (-0.03%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.
