keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

`vectorized_map` causes `tf.function` retracing. #241

Open sebastian-sz opened 2 years ago

sebastian-sz commented 2 years ago

Problem description

It seems that applying some layers that use BaseImageAugmentationLayer with self.auto_vectorize=True to batched input causes tf.function retracing:

import tensorflow as tf
from keras_cv.layers import Solarization

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

for _ in range(50):
    dummy_input = rng.uniform(
        shape=(1, 224, 224, 3), minval=0, maxval=255
    )
    layer(dummy_input)

emits the following warnings:

WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)

Benchmarks

Running simple benchmarks confirms performance degradation with tf.function and batched input:

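import time

import tensorflow as tf
from keras_cv.layers import Solarization

use_tf_function = False  # set to True to compile augment_image with XLA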

rng = tf.random.Generator.from_seed(1234)
layer = Solarization()
results = []

if use_tf_function:
    layer.augment_image = tf.function(layer.augment_image, jit_compile=True)

# Warmup
for _ in range(10):
    layer(rng.uniform(shape=(24, 224, 224, 3), maxval=256))

# Benchmark
for _ in range(100):
    dummy_input = rng.uniform(shape=(24, 224, 224, 3), maxval=256)
    start = time.perf_counter()
    layer(dummy_input)
    stop = time.perf_counter()
    results.append(stop-start)

print(tf.reduce_mean(results))

Case 1: auto_vectorize=True

Without tf.function: 0.067 ms. With tf.function: 0.079 ms.

Case 2: auto_vectorize=False

The issue does not appear with non-batched input (e.g. (224, 224, 3)) or if one sets self.auto_vectorize=False in the layer.

Setting self.auto_vectorize=False yields: without tf.function: 0.017 ms; with tf.function: 0.013 ms.

Case 3: override _batch_augment (if possible)

For operations that are already vectorized, the fastest option is still to override _batch_augment so that it returns self._augment(inputs). This yields: without tf.function: 0.0059 ms; with tf.function: 0.0016 ms.
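Overriding _batch_augment is only valid when the per-image operation gives the same result on a whole batch as it does image by image. A minimal sketch in plain TensorFlow, assuming a simplified solarization rule (pixels at or above a fixed threshold of 128 are inverted as 255 - x); `solarize` is a hypothetical helper, not the KerasCV implementation.

```python
import numpy as np
import tensorflow as tf

def solarize(images, threshold=128.0):
    # Hypothetical simplified solarization: invert pixels >= threshold.
    # The op is purely elementwise, so it accepts one image or a whole batch.
    return tf.where(images < threshold, images, 255.0 - images)

rng = tf.random.Generator.from_seed(1234)
batch = rng.uniform(shape=(4, 8, 8, 3), maxval=256)

# Per-image mapping, roughly what auto-vectorization does via tf.vectorized_map.
mapped = tf.vectorized_map(solarize, batch)

# Native vectorization: a single call on the full batch, analogous to
# overriding _batch_augment to return self._augment(inputs).
native = solarize(batch)

np.testing.assert_allclose(mapped.numpy(), native.numpy())
```

The single batched call avoids vectorized_map entirely, which is consistent with it being the fastest of the three cases above.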

sebastian-sz commented 2 years ago

Improving the performance of Solarization is something I wanted to discuss in another issue.

I wanted to point out that multiple keras_cv preprocessing layers are affected by retracing when applied to batched input.

bhack commented 2 years ago

I don't think we are going to be impacted by the autovectorization/retracing in real use cases:

import tensorflow as tf
from keras_cv.layers.preprocessing import Solarization
from tensorflow.keras.models import Sequential

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

model = Sequential()
model.add(layer)
model.build([24, 224, 224, 3])

# model.predict runs the layer inside Keras' own tf.function
for _ in range(50):
    x = rng.uniform(
        shape=(24, 224, 224, 3), minval=0, maxval=255, dtype=tf.float32)
    _ = model.predict(x)

See my comments in https://github.com/tensorflow/tensorflow/issues/42441

LukeWood commented 2 years ago

Thanks for the detailed report @sebastian-sz

FYI @qlzh727

sebastian-sz commented 2 years ago

@bhack Fair point: using Sequential and the .predict method silences the warnings and brings the inference time to ~0.024 ms regardless of whether self.auto_vectorize is True or False.

However, this is a bit slower than calling the layer directly with self.auto_vectorize=False (0.013 ms) or using native vectorization (0.0016 ms).

Also, model.predict cannot be used inside a tf.data.Dataset map function; one needs to rely on __call__. I'm unsure how execution works inside tf.data.Dataset: I do see differences in inference time depending on self.auto_vectorize, but there are no warnings about retracing.

import time
import tensorflow as tf
from keras_cv.layers import Solarization

model = tf.keras.Sequential()
model.add(Solarization())
model.build([24, 224, 224, 3])

rng = tf.random.Generator.from_seed(1234)
ds = tf.data.Dataset.from_tensor_slices([rng.uniform(shape=(24, 224, 224, 3), maxval=256)]).repeat(100)
ds = ds.map(lambda x: model(x))

# Warm-up pass
for _ in ds:
    continue

# Timed pass
start = time.perf_counter()
for _ in ds:
    continue
stop = time.perf_counter()

# Average time per batch
print((stop - start) / 100)
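A plausible reason for the missing warnings, consistent with the tf.data documentation quoted later in this thread, is that Dataset.map traces map_func and runs it as a graph, so the layer never executes eagerly there. The sketch below only checks that premise with plain TensorFlow; `record_mode` is a hypothetical helper that records tf.executing_eagerly() whenever Python actually runs the function body.

```python
import tensorflow as tf

eager_flags = []

def record_mode(x):
    # Python side effects run only while tf.data traces the function,
    # not once per element, so this list stays short.
    eager_flags.append(tf.executing_eagerly())
    return x * 2

ds = tf.data.Dataset.range(5).map(record_mode)
values = [int(v) for v in ds]

print(values)       # [0, 2, 4, 6, 8]
print(eager_flags)  # every entry is False: the body ran in graph mode
```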
sebastian-sz commented 2 years ago

It seems that wrapping the entire layer call in tf.function (ideally with jit_compile=True) also silences the warnings and provides decent performance in eager mode:

@tf.function(jit_compile=True)
def apply(x):
    return layer(x)

Results: 0.0015 ms with self.auto_vectorize=True, 0.0022 ms with self.auto_vectorize=False, and 0.0015 ms with native vectorization.

This issue can be closed from my end. If there are no further comments, I will close it next week. Thanks for the help!

bhack commented 2 years ago

In general, benchmarking with predict inside a loop is not the best approach:

https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call

For controlling XLA compilation, see my proposal at: https://github.com/keras-team/keras-cv/issues/165#issuecomment-1083502165

LukeWood commented 2 years ago

I wonder if we should wrap our call methods in @tf.function by default to:

a) mute the warnings, and b) make performance consistent.

bhack commented 2 years ago

@sebastian-sz Can you try your initial example with the latest tf-nightly version?

sebastian-sz commented 2 years ago

@bhack Running with 2.9.0-dev20220329 gives very similar numbers, and the retracing persists.

However, I am happy with the performance of the tf.function wrapper.

bhack commented 2 years ago

Internally, vectorized_map traces the function when we are in the default eager mode, whereas models are wrapped in tf.function by default.

If we want to keep the critical section eager-compatible, we need to automatically fall back to the standard map_fn in the base class overload we have made when we are in eager mode.
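The retracing mechanism described above can be made visible by counting how often Python runs the per-image function. A minimal sketch, assuming TensorFlow 2.x behavior in which an eager tf.vectorized_map call wraps its work in a new tf.function each time (consistent with the `pfor.<locals>.f` warnings in the original report); `per_image` and `trace_count` are illustrative names only.

```python
import tensorflow as tf

trace_count = 0

def per_image(image):
    # This Python body runs only when TensorFlow traces it.
    global trace_count
    trace_count += 1
    return image * 2.0

batch = tf.ones((4, 8, 8, 3))

# Eager calls: each call builds and traces a fresh function internally.
for _ in range(3):
    tf.vectorized_map(per_image, batch)
eager_traces = trace_count

@tf.function
def batched(x):
    # Inside a graph, vectorized_map does not create its own tf.function.
    return tf.vectorized_map(per_image, x)

trace_count = 0
batched(batch)  # first call traces once
after_first = trace_count
for _ in range(3):
    batched(batch)  # same shape and dtype: the traced graph is reused
after_reuse = trace_count

print("eager traces:", eager_traces)
print("graph traces:", after_first, "->", after_reuse)
```

Wrapping the caller in tf.function, as suggested earlier in the thread, moves the vectorized_map call into a single cached graph, which matches the observation that the warnings disappear.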

bhack commented 2 years ago

More generally, I think this use of "layer as op" is still a bit confusing:

https://github.com/keras-team/keras-cv/pull/122#discussion_r803214729

LukeWood commented 2 years ago

I don’t want to close it yet because I feel we need to figure out how to effectively communicate this recommendation to users 🤔

LukeWood commented 2 years ago

We could also wrap the base layer's call method in @tf.function if needed, or the batch augmentation method.

bhack commented 2 years ago

It really depends: do you want to silently run some functions in graph mode?

Since the end user/developer does not control vectorization through the API, this would happen behind the scenes without any notification.

At least model.compile still gives the end user control over both eager execution (run_eagerly) and XLA (jit_compile) through its own arguments:

compile(
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    steps_per_execution=None,
    jit_compile=None,
    **kwargs
)
bhack commented 2 years ago

/cc @mdanatg I suppose the situation hasn't changed since Oct 2020 (https://github.com/tensorflow/tensorflow/issues/43710#issuecomment-702430322). What do you think?

mdanatg commented 2 years ago

I think we now have better mechanisms to protect against excessive retracing. Is the error coming from a standard Keras layer, or is it a custom one?

bhack commented 2 years ago

It is inherited from the new base layer in Keras:

https://github.com/keras-team/keras/blob/master/keras/layers/preprocessing/image_preprocessing.py#L841-L844

https://github.com/keras-team/keras/blob/master/keras/layers/preprocessing/image_preprocessing.py#L312-L316

bhack commented 2 years ago

Since tf.vectorized_map does not give us a persistent function object to reuse, it is hard to avoid retracing the function.

So I still believe it is better to automatically call map_fn in eager mode and tf.vectorized_map in graph mode.
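The proposal above can be sketched as a small dispatch helper. This is a minimal sketch assuming the per-image function returns a tensor with the same dtype and structure as its input (otherwise tf.map_fn needs fn_output_signature); `map_over_batch` and `invert` are hypothetical names, not existing Keras or KerasCV APIs.

```python
import numpy as np
import tensorflow as tf

def map_over_batch(fn, images):
    """Apply a per-image fn over the batch dimension of `images`.

    Hypothetical helper: when executing eagerly it uses tf.map_fn, which
    runs fn op by op instead of building a new tf.function per call; while
    a graph is being traced it uses tf.vectorized_map.
    """
    if tf.executing_eagerly():
        return tf.map_fn(fn, images)
    return tf.vectorized_map(fn, images)

def invert(image):
    # Illustrative per-image op whose output dtype matches its input.
    return 255.0 - image

batch = tf.random.uniform((4, 8, 8, 3), maxval=256)

eager_out = map_over_batch(invert, batch)  # map_fn branch
graph_out = tf.function(lambda x: map_over_batch(invert, x))(batch)  # vectorized_map branch

np.testing.assert_allclose(eager_out.numpy(), graph_out.numpy())
```

tf.executing_eagerly() returns False while a tf.function or tf.data pipeline is being traced, so the vectorized branch covers those contexts, while direct eager calls take the slower but trace-free map_fn path.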

bhack commented 2 years ago

Because when graph creation is done implicitly by API design, as in tf.data, it is documented explicitly:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#map

map_func can accept as arguments and return any type of dataset element. Note that irrespective of the context in which map_func is defined (eager vs. graph), tf.data traces the function and executes it as a graph. To use Python code inside of the function you have a few options:

1) Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.
2) Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1). For example:

LukeWood commented 2 years ago

OK, after a lot more digging, I agree with you @bhack: we should apply map_fn in eager mode and vectorized_map in graph mode. We can tackle this after @divyashreepathihalli migrates BaseImageAugmentationLayer to KerasCV; it will be easier to update once it lives in KerasCV.