Open sebastian-sz opened 2 years ago
Improving the performance of Solarization is something I wanted to discuss in another issue.
I wanted to point out that multiple keras_cv preprocessing layers are affected by retracing when applied on batched input.
I don't think we are going to be impacted by the auto-vectorization/retracing in real use cases:
import tensorflow as tf
from keras_cv.layers.preprocessing import Solarization
from tensorflow.keras.models import Sequential

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

model = Sequential()
model.add(layer)
model.build([24, 224, 224, 3])

for _ in range(50):
    x = rng.uniform(shape=(24, 224, 224, 3), minval=0, maxval=255, dtype=tf.float32)
    _ = model.predict(x)
See my comments in https://github.com/tensorflow/tensorflow/issues/42441
Thanks for the detailed report @sebastian-sz
FYI @qlzh727
@bhack fair point - using Sequential and the .predict method silences the warnings and unifies inference time to be ~0.024ms regardless of whether self.auto_vectorize is True or False.
This is however a bit slower than calling the layer directly with self.auto_vectorize=False (0.013ms) or native vectorization (0.0016 ms).
Also, model.predict cannot be used inside a tf.data.Dataset map function - one needs to rely on the __call__ method. I'm unsure how the execution works inside tf.data.Dataset - I do see differences in inference time depending on self.auto_vectorize, but there are no warnings regarding retracing.
import time

import tensorflow as tf
from keras_cv.layers import Solarization

model = tf.keras.Sequential()
model.add(Solarization())
model.build([24, 224, 224, 3])

rng = tf.random.Generator.from_seed(1234)
ds = tf.data.Dataset.from_tensor_slices([rng.uniform(shape=(24, 224, 224, 3), maxval=256)]).repeat(100)
ds = ds.map(lambda x: model(x))

# Warm-up pass, excluded from the timing.
for _ in ds:
    continue

start = time.perf_counter()
for _ in ds:
    continue
stop = time.perf_counter()
print((stop - start) / 100)  # average per-batch time
It seems like wrapping the entire layer in tf.function (even better if with jit_compile=True) also silences the warnings and provides decent performance in eager mode:
@tf.function(jit_compile=True)
def apply(x):
    return layer(x)
0.0015ms for self.auto_vectorize=True
0.0022ms for self.auto_vectorize=False
0.0015ms for native vectorization.
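For context on how per-call timings like these can be collected, here is a minimal sketch of a timing helper. It assumes the numbers are averages over repeated calls taken after a warm-up; the helper name time_per_call and its parameters are hypothetical and do not come from the thread. The warm-up matters because the first call to a tf.function includes tracing (and XLA compilation when jit_compile=True), which would otherwise skew the average.

```python
import time


def time_per_call(fn, arg, warmup=1, runs=100):
    """Return the mean wall-clock seconds per call of fn(arg)."""
    for _ in range(warmup):
        fn(arg)  # warm-up calls are excluded from the measurement
    start = time.perf_counter()
    for _ in range(runs):
        fn(arg)
    stop = time.perf_counter()
    return (stop - start) / runs


# Stand-in workload; in the thread this would be apply(x) or layer(x).
mean_seconds = time_per_call(lambda v: [p * 2 for p in v], list(range(1000)), runs=50)
print(f"{mean_seconds * 1e3:.4f} ms per call")
```

The same helper can be pointed at the eager layer call and at the tf.function wrapper so that both are measured under identical conditions.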
This issue can be closed from my end. If no further comments appear I will close this issue starting next week. Thanks for the help!
Generally it is not the best solution to benchmark in a loop with predict:
https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
For controlling the XLA compilation see my proposal at: https://github.com/keras-team/keras-cv/issues/165#issuecomment-1083502165
I wonder if we should @tf.function our call methods by default to:
a.) mute warnings b.) make performance consistent.
@sebastian-sz Can you try your initial example with the last tf-nightly version?
@bhack Running with 2.9.0-dev20220329 gives very similar numbers and retracing persists.
I am however happy with the performance from the tf.function wrapper.
Vectorized map is going to trace the function internally if we are in the default eager mode, but models are wrapped in tf.function by default.
If we want to keep the critical section eager-compatible, we need to automate the conditional call to the standard map_fn (when we are in eager mode) in the base class overload we have done.
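To make the retracing point concrete, here is a small sketch (not KerasCV code) that counts traces using a Python side effect, which only runs while TensorFlow traces a function. It assumes the cost discussed here comes from a fresh tf.function being built on every eager call; the names body and trace_count are illustrative only.

```python
import tensorflow as tf

trace_count = 0


def body(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x * 2.0


x = tf.ones((4, 8))

# A new tf.function object per call starts with an empty cache,
# so every call has to trace again.
for _ in range(3):
    tf.function(body)(x)
fresh_traces = trace_count

# Reusing one tf.function object traces once for a fixed input signature.
trace_count = 0
reused = tf.function(body)
for _ in range(3):
    reused(x)
reused_traces = trace_count

print("fresh objects:", fresh_traces, "reused object:", reused_traces)
```

Holding on to a single traced function, as a model or an outer tf.function wrapper does, is what avoids the repeated tracing.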
More in general I think that this use of "layer as op" is still a little bit confusing:
https://github.com/keras-team/keras-cv/pull/122#discussion_r803214729
I don’t want to close it yet because I feel we need to figure out how to effectively communicate this recommendation to users 🤔
We could also @tf.function the base layer's call method if needed, or the augment batch method.
It really depends... do you want to silently be in graph mode for some functions?
As the end user/developer doesn't control the vectorization in the API, it is something that you are going to do behind the scenes without any notification.
At least model.compile still gives control to the end user for both eager and jit_compile (XLA) through its own args:
compile(
optimizer='rmsprop',
loss=None,
metrics=None,
loss_weights=None,
weighted_metrics=None,
run_eagerly=None,
steps_per_execution=None,
jit_compile=None,
**kwargs
)
/cc @mdanatg I suppose that the situation hasn't evolved since Oct 2020 https://github.com/tensorflow/tensorflow/issues/43710#issuecomment-702430322. What do you think?
I think we now have better mechanisms to protect against excessive retracing. Is the error coming from a standard Keras layer, or is it a custom one?
As we don't have an object with tf.vectorized_map it is hard not to retrace the function.
So I still believe that it is better to automatically call map_fn in eager mode and tf.vectorized_map in graph mode.
Because when the graph creation is done implicitly by API design, like in tf.data, it is documented explicitly:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#map
map_func can accept as arguments and return any type of dataset element. Note that irrespective of the context in which map_func is defined (eager vs. graph), tf.data traces the function and executes it as a graph. To use Python code inside of the function you have a few options:
1) Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.
2) Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1). For example:
Ok, after lots more digging and time I agree with you @bhack: we should apply map_fn in eager and vectorized_map in graph. We can tackle this after @divyashreepathihalli migrates BaseImageAugmentationLayer to KerasCV. It will be easier to update once it is in KerasCV.
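The agreed direction can be sketched roughly as follows. This is only an illustration under stated assumptions, not the KerasCV implementation: map_over_batch is a hypothetical helper, and solarize is a simplified stand-in augmentation rather than the library's Solarization layer. The only TensorFlow calls used are tf.executing_eagerly, tf.map_fn, tf.vectorized_map and tf.where.

```python
import tensorflow as tf


def solarize(image, threshold=128.0):
    # Simplified stand-in augmentation: invert pixels at or above the threshold.
    return tf.where(image < threshold, image, 255.0 - image)


def map_over_batch(fn, batch):
    """Apply fn to each element along axis 0, choosing the mapper by mode."""
    if tf.executing_eagerly():
        return tf.map_fn(fn, batch)  # eager: plain per-element map
    return tf.vectorized_map(fn, batch)  # graph: auto-vectorized map


batch = tf.random.uniform((4, 8, 8, 3), maxval=255.0)

eager_out = map_over_batch(solarize, batch)  # takes the map_fn branch


@tf.function
def graph_call(x):
    return map_over_batch(solarize, x)  # takes the vectorized_map branch


graph_out = graph_call(batch)
print(eager_out.shape, graph_out.shape)
```

Because tf.data traces map functions into a graph, the same helper would take the vectorized branch when used inside Dataset.map.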
Problem description
It seems like applying some layers that use BaseImageAugmentationLayer and self.auto_vectorize=True over batched input triggers tf.function retracing warnings.
Benchmarks
Running simple benchmarks confirms performance degradation with tf.function and batched input:
Case 1: auto_vectorize=True
Without tf.function: 0.067 ms. With: 0.079 ms.
Case 2: auto_vectorize=False
The issue doesn't pop up with non-batched input, e.g. (224, 224, 3), or if one changes self.auto_vectorize=False in the layer.
Setting self.auto_vectorize=False will yield: Without tf.function: 0.017 ms. With: 0.013 ms.
Case 3: override _batch_augment (if possible)
In case of vectorized operations, the fastest option is still overriding _batch_augment to return self._augment(inputs). This will yield: Without tf.function: 0.0059 ms. With: 0.0016 ms.
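As a rough illustration of why Case 3 works for elementwise augmentations, the sketch below compares mapping a function over the batch one image at a time with calling it once on the whole batch. The solarize function here is a simplified stand-in written for this example, not the KerasCV layer, and the shapes are arbitrary.

```python
import tensorflow as tf


def solarize(x, threshold=128.0):
    # Elementwise stand-in: works unchanged for a single image or a full batch.
    return tf.where(x < threshold, x, 255.0 - x)


batch = tf.random.uniform((24, 32, 32, 3), maxval=255.0)

per_image = tf.map_fn(solarize, batch)  # one image at a time
whole_batch = solarize(batch)  # single call on the full batch

same = bool(tf.reduce_all(tf.abs(per_image - whole_batch) < 1e-5))
print("results match:", same)
```

When the augmentation depends on per-image state (for example a random parameter drawn per sample), this shortcut does not apply directly, which is presumably why the original text says "if possible".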