keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

RandAugment performance issue #2122

Open sup3rgiu opened 1 year ago

sup3rgiu commented 1 year ago

Current Behavior:

Since v0.5.0, most preprocessing layers are subclasses of the new VectorizedBaseImageAugmentationLayer. However, RandAugment still relies on _batch_augment() of BaseImageAugmentationLayer, where each image of the input batch is processed independently with tf.map_fn. As a result, RandAugment does not take advantage of the vectorized implementation of the preprocessing layers.

Expected Behavior:

RandAugment should process the whole input batch. The fix is quite straightforward:

rand_augment = keras_cv.layers.RandAugment(value_range=(-1, 1), augmentations_per_image=3, magnitude=0.5)

rand_augment._batch_augment = rand_augment._augment  # (1)
rand_augment._random_choice.batchwise = True  # (2)

With (1), the _batch_augment() method points to _augment(), since there is no need to use tf.map_fn.

As for (2), under the hood _augment() of RandAugment calls _augment() of RandomAugmentationPipeline, which in turn calls RandomChoice. Since RandomChoice inherits from BaseImageAugmentationLayer, when it is called with a batch of images it falls back to _batch_augment(). Again, we don't want to rely on _batch_augment() but on _augment(); for RandomChoice, this can be achieved by setting its batchwise property to True.
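For intuition, batchwise=True roughly means one augmentation is drawn per batch rather than one per image. A minimal NumPy sketch of that dispatch pattern (the names here are illustrative, not keras_cv's actual internals):

```python
import numpy as np

def random_choice(images, augmentations, batchwise, rng):
    """Apply one randomly chosen augmentation.

    batchwise=True : one draw for the whole batch (vectorized, fast).
    batchwise=False: an independent draw per image (a map_fn-style loop).
    """
    if batchwise:
        fn = augmentations[rng.integers(len(augmentations))]
        return fn(images)  # single call over the full batch
    # per-image loop, analogous to tf.map_fn over the batch dimension
    return np.stack(
        [augmentations[rng.integers(len(augmentations))](img) for img in images]
    )

rng = np.random.default_rng(0)
augs = [lambda x: x + 1.0, lambda x: x * 2.0]
batch = np.zeros((4, 8, 8, 3))
out = random_choice(batch, augs, batchwise=True, rng=rng)
print(out.shape)  # (4, 8, 8, 3)
```

Note the trade-off: in the batchwise path a single transform is shared by all images in the batch, which is exactly what makes it vectorizable.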

Of course, these changes should be made in the inner implementation of keras_cv.layers.RandAugment, but this shows the simplicity of the fix.
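To see why routing _batch_augment to _augment is safe for elementwise ops, here is a small NumPy sketch (augment is a hypothetical stand-in for a vectorized preprocessing layer) showing that one call over the batch matches a per-image map_fn-style loop:

```python
import numpy as np

# Hypothetical vectorized augmentation: works on a full (N, H, W, C) batch.
def augment(batch):
    return batch * 0.5 + 1.0  # any elementwise, batch-shaped op

batch = np.arange(2 * 2 * 2 * 3, dtype=np.float32).reshape(2, 2, 2, 3)

# What _batch_augment + tf.map_fn effectively does: one call per image.
per_image = np.stack([augment(img) for img in batch])

# What the fix does: a single call over the whole batch.
vectorized = augment(batch)

print(np.allclose(per_image, vectorized))  # True
```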

Benchmark:

Execution time over 10 CIFAR-10 batches of size 128:

Original implementation: 0:01:16.377174 (1 minute, 16 seconds)
Fixed implementation: 0:00:03.288814 (3 seconds)
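A rough timing harness in the spirit of the Colab benchmark might look like this (the layer_fn placeholder and the CIFAR-10-like shapes are assumptions; substitute the real rand_augment layer):

```python
import time
import numpy as np

def benchmark(layer_fn, num_batches=10, batch_size=128):
    """Time layer_fn over several CIFAR-10-sized batches."""
    batch = np.random.uniform(
        -1.0, 1.0, size=(batch_size, 32, 32, 3)
    ).astype("float32")
    start = time.perf_counter()
    for _ in range(num_batches):
        layer_fn(batch)
    return time.perf_counter() - start

elapsed = benchmark(lambda x: x)  # replace the identity with rand_augment
print(elapsed >= 0.0)  # True
```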

Colab:

https://colab.research.google.com/drive/1ExgceI_WwNssbgiHCKcSEKVVEAgejEL0?usp=sharing

Version:

0.6.4

Anything else:

The speed-up can be further increased by implementing RandomColorDegeneration and Equalization preprocessing layers in a vectorized manner, since out of the 8 possible random augmentations (AutoContrast, Equalization, Solarization, RandomColorDegeneration, RandomContrast, RandomBrightness, RandomShear, RandomTranslation), these are the only two that are currently not vectorized. See https://github.com/keras-team/keras-cv/issues/2120 for a possible vectorized implementation.

LukeWood commented 1 year ago

Vectorizing RandAugment would be an awesome win for everyone using it. Unfortunately, I don't think it is quite as simple as it has been laid out here. In the proposed fix you are applying the SAME transformations to each image in a batch.

This makes it so the samples are correlated, and batches may not have sufficient random sampling for optimal convergence. If we truly want to make RandAugment vectorized, we'll need to use a ton of tf.where() and tf.random.* conditions (probably one per layer) in order to create a fully vectorized implementation. It is pretty nontrivial, sadly.
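The per-image selection described above can be sketched with boolean/index masks: every candidate transform runs on the whole batch, and a per-image random index picks which candidate's output each image keeps (a NumPy stand-in for the tf.where-style selection; the function name is illustrative):

```python
import numpy as np

def vectorized_random_apply(batch, transforms, rng):
    """Apply an independently chosen transform to each image, fully vectorized.

    Every transform is computed for the whole batch; a per-image index then
    selects which candidate output each image keeps (tf.where-style gather).
    """
    candidates = np.stack([t(batch) for t in transforms])   # (T, N, H, W, C)
    choice = rng.integers(len(transforms), size=batch.shape[0])  # (N,)
    return candidates[choice, np.arange(batch.shape[0])]    # (N, H, W, C)

rng = np.random.default_rng(0)
transforms = [lambda x: x + 1.0, lambda x: -x]
batch = np.zeros((4, 2, 2, 3))
out = vectorized_random_apply(batch, transforms, rng)
print(out.shape)  # (4, 2, 2, 3)
```

The cost is that every candidate transform is computed for every image and most results are thrown away, which is part of why a fully decorrelated vectorized RandAugment is nontrivial.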

Note that while the sampled factors are unique per image (drawn per transformation), the same transformation set is being picked for every image. So in my example, all images are being sheared and having their brightness adjusted; only the sampled factors differ.

[image: augmented batch where every image received the same transformation set]

Compare this to the first, where each image has distinct transformation samples: [image]

Notebook source:

https://colab.research.google.com/drive/1qmFBk0OWl7LnxiYlhKAmV2hy0HYyyDtR?usp=sharing

Additional notes (just an FYI):