keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

BUG: `RandomContrast` slows down training 6-fold #581

Closed DavidLandup0 closed 1 year ago

DavidLandup0 commented 2 years ago

Augmentation will obviously slow down training, but it shouldn't cause a 6-fold slowdown. This happens with the RandomContrast layer, which makes training time per epoch grow from ~100s to ~600s.

I'd share a Colab notebook, but there seems to be an issue with importing KerasCV on Colab, so here are the steps to reproduce:

import keras_cv
print(f'KerasCV version {keras_cv.__version__}')
import tensorflow as tf
print(f'TF version {tf.__version__}')
from tensorflow import keras
print(f'Keras version {keras.__version__}')

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Run 1: model with RandomContrast augmentation (~600s/epoch in this report)
aug = keras.Sequential([
    keras_cv.layers.RandomContrast(factor=0.2)
])
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    aug,
    keras.applications.EfficientNetV2B0(weights=None, include_top=False),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train)

# Run 2: the same model without augmentation (~100s/epoch in this report)
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.applications.EfficientNetV2B0(weights=None, include_top=False),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train)
bhack commented 2 years ago

I suppose we are in the same situation as https://github.com/tensorflow/tensorflow/issues/56242

The root cause is https://github.com/keras-team/keras-cv/issues/291

WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformFullIntV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomGetKeyCounter cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting AdjustContrastv2 cause Input "contrast_factor" of op 'AdjustContrastv2' expected to be loop invariant.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformFullIntV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomGetKeyCounter cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting AdjustContrastv2 cause Input "contrast_factor" of op 'AdjustContrastv2' expected to be loop invariant.
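The warnings above mean pfor found no vectorized converter for these ops (or, for AdjustContrastv2, found a per-image contrast_factor where a loop-invariant one is required) and fell back to a while_loop. A minimal NumPy sketch of what that fallback amounts to, assuming the per-channel formula from the tf.image.adjust_contrast documentation, `(x - mean) * contrast_factor + mean`: the scalar-factor op is invoked once per image instead of once per batch. `adjust_contrast_batch` and `fallback_per_image` are hypothetical helpers, not TF or KerasCV APIs; NumPy only stands in for the TF kernels.

```python
import numpy as np


def adjust_contrast_batch(images, contrast_factor):
    """Hypothetical NumPy stand-in for tf.image.adjust_contrast.

    images: float array of shape [..., H, W, C].
    contrast_factor: one scalar shared by every image, mirroring the
    scalar-only argument of the TF op.
    """
    # Per-image, per-channel mean over the spatial axes (H, W).
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * contrast_factor + mean


def fallback_per_image(images, factors):
    """Emulates the while_loop fallback: one op call per batch element."""
    out = np.empty_like(images)
    for i in range(images.shape[0]):
        # Each iteration sees a single image and its own scalar factor,
        # so the op is launched batch_size times instead of once.
        out[i] = adjust_contrast_batch(images[i], factors[i])
    return out


rng = np.random.default_rng(0)
images = rng.uniform(0.0, 255.0, size=(4, 32, 32, 3)).astype(np.float32)
factors = rng.uniform(0.8, 1.2, size=4).astype(np.float32)  # one factor per image
augmented = fallback_per_image(images, factors)
print(augmented.shape)  # (4, 32, 32, 3)
```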
LukeWood commented 2 years ago

I will pin this issue. This is a significant slowdown, maybe we need to consider manually vectorizing - which would be very unfortunate.

bhack commented 2 years ago

> I will pin this issue. This is a significant slowdown, maybe we need to consider manually vectorizing - which would be very unfortunate.

I think it would be more useful to brainstorm a solution for the older, more general problem at https://github.com/keras-team/keras-cv/issues/291 than to pin every single issue, as this is already the second one (https://github.com/tensorflow/tensorflow/issues/56242).

LukeWood commented 2 years ago

Would you mind explaining what you have in mind for a "solution for the old and more general problem"? I'm not sure how we would solve this in the general case.

bhack commented 2 years ago

Since the superset of this issue is https://github.com/keras-team/keras-cv/issues/291, and the root cause is our choice of vectorized_map together with within-the-batch randomization (as we discussed in many early tickets on our augmentation design), we could:

bhack commented 2 years ago

About my last comment /cc @ishark @wangpengmit

MarkDaoust commented 2 years ago

> maybe we need to consider manually vectorizing - which would be very unfortunate.

https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast

adjust_contrast can take a stack of images, but only a scalar contrast factor, not a list. That's too bad, especially for such a simple function.
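For an op this simple, the "manual vectorizing" mentioned above is mostly a matter of broadcasting. A minimal NumPy sketch, assuming the same per-channel formula documented for tf.image.adjust_contrast: reshaping a [B] vector of factors to [B, 1, 1, 1] applies a different factor to every image in a single pass, and the result matches a per-image loop. `adjust_contrast_per_image` is a hypothetical name, not an existing TF or KerasCV function.

```python
import numpy as np


def adjust_contrast_per_image(images, factors):
    """Hypothetical vectorized contrast adjustment with one factor per image.

    images:  float array of shape [B, H, W, C].
    factors: float array of shape [B], one contrast factor per image.
    """
    # Per-image, per-channel mean: shape [B, 1, 1, C].
    mean = images.mean(axis=(1, 2), keepdims=True)
    # Reshape factors to [B, 1, 1, 1] so they broadcast over H, W and C.
    factors = np.asarray(factors, dtype=images.dtype).reshape(-1, 1, 1, 1)
    return (images - mean) * factors + mean


rng = np.random.default_rng(0)
images = rng.uniform(0.0, 255.0, size=(8, 32, 32, 3)).astype(np.float32)
factors = rng.uniform(0.8, 1.2, size=8)

batched = adjust_contrast_per_image(images, factors)

# Reference: the scalar-factor formula applied image by image in a Python loop.
looped = np.stack([
    (img - img.mean(axis=(0, 1))) * np.float32(f) + img.mean(axis=(0, 1))
    for img, f in zip(images, factors)
])
print(np.allclose(batched, looped))  # True
```

The same broadcasting approach should carry over to TF with tf.reduce_mean(..., keepdims=True) and tf.reshape; nothing in the math requires contrast_factor to be a scalar, only the op signature does.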

bhack commented 2 years ago

> adjust_contrast can take a stack of images, but only a scalar contrast factor, not a list. That's too bad, especially for such a simple function.

Isn't this namespace orphaned? See https://github.com/keras-team/keras-cv/issues/74#issuecomment-1035597742

Also, I have never heard that we want to contribute to or coordinate with the tf.image.* API. @MarkDaoust, see the full thread from the early weeks of the KerasCV repo at https://github.com/keras-team/keras-cv/pull/122#discussion_r803936244

davidanoel commented 2 years ago

The same is happening with RandAugment. Without RandAugment, each epoch takes around 35 s on my machine; with RandAugment, it takes about 2 min 25 s. Is a resolution on the roadmap?

bhack commented 2 years ago

@atlanticstarr1 We have started a thread at https://github.com/tensorflow/tensorflow/issues/55639#issuecomment-1233073292

You could try to ping there.

bhack commented 2 years ago

To confirm my hypothesis about @DavidLandup0's original RandomContrast example that opened this ticket (https://github.com/keras-team/keras-cv/issues/581#issue-1303119198):

I've tested it with TF 2.10.0 on Colab, and the overhead of the "within the batch randomization/vectorized_map fallback" is huge.

As a quick workaround to isolate the effect, using a constant factor (e.g. 10) at https://github.com/keras-team/keras/blob/v2.10.0/keras/layers/preprocessing/image_preprocessing.py#L1675 gives 46ms/step.

With the official "within the batch/vectorized_map fallback" we get 152ms/step.

So we had a 3.30x performance drop with the default batch size of 32, and this gap is bound to grow with larger batches/input sizes.
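The reported step times translate into the quoted ratio as follows. The per-epoch figures are only a rough extrapolation, assuming the Colab run used the 50,000-image CIFAR-10 training set and batch size 32 from the original repro; they are not measurements.

```python
import math

# Step times reported in the comment above (Colab, TF 2.10.0, batch size 32).
constant_factor_ms = 46.0  # ms/step with a constant contrast factor
per_image_ms = 152.0       # ms/step with per-image factors (vectorized_map fallback)

slowdown = per_image_ms / constant_factor_ms
print(f"slowdown: {slowdown:.2f}x")  # slowdown: 3.30x

# Assumption: one epoch over the 50,000 CIFAR-10 training images from the repro.
steps_per_epoch = math.ceil(50_000 / 32)  # 1563 steps
for label, ms in [("constant", constant_factor_ms), ("per-image", per_image_ms)]:
    # Prints ~72 s/epoch and ~238 s/epoch respectively.
    print(f"{label}: ~{steps_per_epoch * ms / 1000:.0f} s/epoch")
```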

Please check it yourself with this Colab so we are on the same page without waiting for the private GPU CI auth:

https://colab.research.google.com/gist/bhack/355dd2a56c734bb04f7025e15fb2a53d/randomcontrast_benchmark.ipynb#scrollTo=b2L6G_XYPzOz

bhack commented 2 years ago

/cc @martin-gorner

bhack commented 1 year ago

Since we may also introduce this issue in the new 3D preprocessing API with https://github.com/keras-team/keras-cv/pull/986, I want to clarify this example, as it is similar to other KLP cases.

Assuming that we have 100% coverage of the pfor converters (which we obviously don't actually have), let's see what happens in this case.

This is the AdjustContrastv2 converter in TF, which we need because this layer's implementation calls the tf.image.adjust_contrast API: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/parallel_for/pfor.py#L1724

@RegisterPFor("AdjustContrastv2")
def _convert_adjust_contrastv2(pfor_input):
  images = pfor_input.stacked_input(0)
  contrast_factor = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)

As you can see, contrast_factor is read with unstacked_input, so it must be loop invariant.

With our within-the-batch augmentation policy, however, we want a different random factor for each image in the batch.

vectorized_map/pfor cannot stack contrast_factor in the converter because the tf.image.adjust_contrast API signature, and the underlying CPU and GPU ops and kernels, are designed to accept stacked/batched images but only a scalar contrast_factor float multiplier for adjusting contrast.

tf.image.adjust_contrast(images, contrast_factor)
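To make the stacked/unstacked distinction concrete, here is a deliberately simplified toy model, in NumPy, of the decision the converter forces. It is not how pfor is implemented: `toy_pfor` and `adjust_contrast` are hypothetical, and "loop invariant" is approximated by checking that all per-image factors are equal (pfor decides this from the graph, not from values). It only shows the consequence: one kernel call when the factor is shared, one call per image when it is not.

```python
import numpy as np


def adjust_contrast(images, contrast_factor):
    """NumPy stand-in for the AdjustContrastv2 kernel: scalar factor only."""
    if np.ndim(contrast_factor) != 0:
        raise ValueError("contrast_factor must be a scalar")
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * contrast_factor + mean


def toy_pfor(images, factors):
    """Hypothetical, much-simplified model of how pfor treats this op.

    Returns (output, number_of_kernel_calls).
    """
    factors = np.asarray(factors)
    if np.all(factors == factors[0]):
        # Loop-invariant factor: it can be passed as an *unstacked* input,
        # so the whole batch goes through a single kernel call.
        return adjust_contrast(images, factors[0]), 1
    # Per-image factors would need a *stacked* input, which the kernel does
    # not accept, so fall back to one call per batch element (the while_loop).
    out = np.stack([adjust_contrast(img, f) for img, f in zip(images, factors)])
    return out, len(images)


images = np.random.default_rng(0).uniform(0, 255, size=(32, 8, 8, 3))
_, calls_same = toy_pfor(images, np.full(32, 1.1))
_, calls_random = toy_pfor(images, np.linspace(0.8, 1.2, 32))
print(calls_same, calls_random)  # 1 32
```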

So this is not strictly an issue with vectorized_map itself; vectorized_map is mainly affected by the missing converter coverage for some ops that we use (or will want to use).

The main problem is that we want to adopt a within-the-batch policy, and this will carry a large performance overhead, regardless of whether we use vectorized_map or map_fn, whenever the underlying TF ops don't support stacked versions of the args that we want to randomize differently for every batch element.

So I think we have two main options here (other than extending converter coverage):

  • Rewrite all the TF op APIs and the related CPU/GPU ops/kernels to support stacked versions of the arguments that we want to randomize within the batch
  • Re-evaluate the within-the-batch policy, weighing the performance overhead it introduces against the faster(?) convergence rate. I have never seen any experimental data here supporting the gain we get from this policy in terms of the FLOPS/epochs convergence ratio.

The only paper mentioned in this repo was https://github.com/keras-team/keras-cv/issues/372#issuecomment-1110827256, which would at least let us apply the same randomized factor to sub-batches, partially limiting the performance impact of the current "full within-the-batch" policy.
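The sub-batch idea can be sketched as follows. This is a minimal NumPy illustration of applying one random factor per group of images, not the algorithm from the referenced paper; `sub_batch_contrast` and its parameters are hypothetical. With `num_groups=1` it degenerates to per-batch randomization, and with `num_groups` equal to the batch size it becomes the current per-image policy, so the number of scalar-op calls (and the overhead) scales with `num_groups`.

```python
import numpy as np


def adjust_contrast(images, contrast_factor):
    """NumPy stand-in for a batched op that takes a scalar factor."""
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * contrast_factor + mean


def sub_batch_contrast(images, num_groups, lower, upper, rng):
    """Hypothetical sub-batch policy: one random factor per group of images.

    The batched op is called num_groups times instead of once per image.
    """
    groups = np.array_split(images, num_groups, axis=0)
    factors = rng.uniform(lower, upper, size=num_groups)
    out = np.concatenate(
        [adjust_contrast(g, f) for g, f in zip(groups, factors)], axis=0
    )
    return out, factors


rng = np.random.default_rng(0)
images = rng.uniform(0.0, 255.0, size=(32, 16, 16, 3))
out, factors = sub_batch_contrast(images, num_groups=4, lower=0.8, upper=1.2, rng=rng)
print(out.shape, factors.shape)  # (32, 16, 16, 3) (4,)
```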

tanzhenyu commented 1 year ago

> Since we may also introduce this issue in the new 3D preprocessing API with #986, I want to clarify this example, as it is similar to other KLP cases.
>
> Assuming that we have 100% coverage of the pfor converters (which we obviously don't actually have), let's see what happens in this case.
>
> This is the AdjustContrastv2 converter in TF, which we need because this layer's implementation calls the tf.image.adjust_contrast API: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/parallel_for/pfor.py#L1724
>
> @RegisterPFor("AdjustContrastv2")
> def _convert_adjust_contrastv2(pfor_input):
>   images = pfor_input.stacked_input(0)
>   contrast_factor = pfor_input.unstacked_input(1)
>   return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)
>
> As you can see, contrast_factor is read with unstacked_input, so it must be loop invariant.
>
> With our within-the-batch augmentation policy, however, we want a different random factor for each image in the batch.
>
> vectorized_map/pfor cannot stack contrast_factor in the converter because the tf.image.adjust_contrast API signature, and the underlying CPU and GPU ops and kernels, are designed to accept stacked/batched images but only a scalar contrast_factor float multiplier for adjusting contrast.
>
> tf.image.adjust_contrast(images, contrast_factor)
>
> So this is not strictly an issue with vectorized_map itself; vectorized_map is mainly affected by the missing converter coverage for some ops that we use (or will want to use).
>
> The main problem is that we want to adopt a within-the-batch policy, and this will carry a large performance overhead, regardless of whether we use vectorized_map or map_fn, whenever the underlying TF ops don't support stacked versions of the args that we want to randomize differently for every batch element.
>
> So I think we have two main options here (other than extending converter coverage):
>
>   • Rewrite all the TF op APIs and the related CPU/GPU ops/kernels to support stacked versions of the arguments that we want to randomize within the batch
>   • Re-evaluate the within-the-batch policy, weighing the performance overhead it introduces against the faster(?) convergence rate. I have never seen any experimental data here supporting the gain we get from this policy in terms of the FLOPS/epochs convergence ratio.
>
> The only paper mentioned in this repo was #372 (comment), which would at least let us apply the same randomized factor to sub-batches, partially limiting the performance impact of the current "full within-the-batch" policy.

I'm not sure I really understand the argument here. To me, it seems the main issue is that contrast_factor cannot be a stacked input here? Is this more a TF problem than a KerasCV problem?

bhack commented 1 year ago

It is contrast_factor in this example, but it is just one instance of a more general issue. Back in April, the list in KerasCV was already quite long: https://github.com/tensorflow/tensorflow/issues/55639#issuecomment-1112227581

And this will happen every time you want to randomize, within the batch, an argument of an op where that argument is a scalar by TF API design.

We could say that it is a TF issue and not our concern, but who really has the resources to change all these APIs/ops/kernels in TF for your new design needs?

At the same time, Keras is no longer a multi-backend library: users are directly impacted by the performance issues of this design, and they cannot switch to an alternative "backend". So, from a user's point of view, the separation between Keras issues and TF issues often doesn't make sense.

So a limitation of your within-the-batch randomization policy directly impacts the KerasCV user base, and claiming that it is a TF (team?) issue does not solve anyone's problem.

At the same time, we don't have (public?) experimental data on the accuracy/epochs gain from randomizing an augmentation for each batch element, which is the design choice that created all this performance overhead, given the TF API design we currently have for many ops.

jasonrichdarmawan commented 1 year ago

tensorflow-macos==2.10.0 and 2.11.0 have this issue.

I use RandomRotation.

bhack commented 1 year ago

@kidfrom We have already extensively discussed this in many tickets. The last thread was at https://github.com/tensorflow/tensorflow/issues/55639#issuecomment-1310229476

bhack commented 1 year ago

To track the performance discussion we recently had on the RepeatedAugmentation layer: https://github.com/keras-team/keras-cv/pull/1293#discussion_r1083533391

If we can have the same image with different augmentations in the batch, why can't we have different images with the same augmentation param in the batch?
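The question above can be sketched concretely. A minimal NumPy illustration, assuming a repeated-augmentation setup where each image appears `repeats` times: instead of drawing a factor per copy, each repeat of the batch shares one factor, so every image is still seen under several different augmentations while the scalar-only op runs `repeats` times rather than once per image. `repeated_batchwise_contrast` is hypothetical and is not the KerasCV RepeatedAugmentation layer.

```python
import numpy as np


def adjust_contrast(images, contrast_factor):
    """NumPy stand-in for a batched op that takes a scalar factor."""
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * contrast_factor + mean


def repeated_batchwise_contrast(images, repeats, lower, upper, rng):
    """Hypothetical variant of repeated augmentation.

    Every image appears `repeats` times with different factors, but within
    each repeat the whole batch shares one factor, so the scalar-only op is
    called `repeats` times instead of once per image copy.
    """
    factors = rng.uniform(lower, upper, size=repeats)
    out = np.concatenate([adjust_contrast(images, f) for f in factors], axis=0)
    return out, factors


rng = np.random.default_rng(0)
images = rng.uniform(0.0, 255.0, size=(8, 16, 16, 3))
out, factors = repeated_batchwise_contrast(images, repeats=3, lower=0.8, upper=1.2, rng=rng)
print(out.shape)  # (24, 16, 16, 3)
```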

/cc @LukeWood

jbischof commented 1 year ago

Closing due to staleness.