keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars 115 forks

Use `ops.rsqrt`, improve normalization layers and enable ops fusion in tflite #892

Closed james77777778 closed 10 months ago

james77777778 commented 10 months ago

Fixes #824

This PR accomplishes the following:

  1. Adds support for `rsqrt` in the NumPy backend (using JAX's implementation).
  2. Replaces `1 / ops.sqrt(x)` with `ops.rsqrt(x)` for improved speed.
  3. Reorders the ops in the normalization layers to unify the implementation and match the expression used by `tf.nn.batch_normalization`.
  4. Ensures 100% unit test coverage for all normalization layers.
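
To illustrate items 2 and 3, here is a minimal NumPy sketch, not the code from this PR. The helper names `rsqrt`, `batch_norm_naive`, and `batch_norm_reordered` are hypothetical, and the reordered form follows the expression used by `tf.nn.batch_normalization`, where the scale is folded into the reciprocal standard deviation before touching the input.

```python
import numpy as np


def rsqrt(x):
    """Elementwise reciprocal square root, 1 / sqrt(x) (illustrative helper)."""
    return 1.0 / np.sqrt(x)


def batch_norm_naive(x, mean, variance, gamma, beta, epsilon=1e-3):
    # Textbook ordering: center, divide by the standard deviation,
    # then scale and shift.
    return (x - mean) / np.sqrt(variance + epsilon) * gamma + beta


def batch_norm_reordered(x, mean, variance, gamma, beta, epsilon=1e-3):
    # Ordering of tf.nn.batch_normalization: fold gamma into the reciprocal
    # standard deviation, so the work on x is one multiply and one add.
    inv = rsqrt(variance + epsilon) * gamma
    return x * inv + (beta - mean * inv)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
mean = x.mean(axis=0)
variance = x.var(axis=0)
gamma = rng.normal(size=(8,))
beta = rng.normal(size=(8,))

naive = batch_norm_naive(x, mean, variance, gamma, beta)
reordered = batch_norm_reordered(x, mean, variance, gamma, beta)

# The two orderings agree up to floating-point rounding.
print(np.allclose(naive, reordered))
```

Both functions compute the same result; the reordered form only changes which intermediate values are materialized.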

After step 3, TFLite recognizes the CONV+BN+ReLU pattern and fuses the ops successfully.
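
As background on why a CONV+BN+ReLU chain can be fused at all, the sketch below checks the underlying algebra numerically. It is not TFLite's fusion pass: it assumes inference-mode (fixed) statistics, models a 1x1 convolution as a matrix multiply, and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8))       # (pixels, in_channels)
kernel = rng.normal(size=(8, 4))   # 1x1 conv kernel: (in_channels, out_channels)
bias = rng.normal(size=(4,))
mean = rng.normal(size=(4,))
variance = rng.uniform(0.5, 2.0, size=(4,))
gamma = rng.normal(size=(4,))
beta = rng.normal(size=(4,))
epsilon = 1e-3

# Per-channel multiplier and offset of the normalization, written in the
# "x * inv + (beta - mean * inv)" form.
inv = gamma / np.sqrt(variance + epsilon)
offset = beta - mean * inv

# Unfused: convolution, then batch normalization, then ReLU.
unfused = np.maximum((x @ kernel + bias) * inv + offset, 0.0)

# Fused: the multiplier is folded into the kernel and the offset into the
# bias, leaving a single convolution followed by ReLU.
fused_kernel = kernel * inv
fused_bias = bias * inv + offset
fused = np.maximum(x @ fused_kernel + fused_bias, 0.0)

print(np.allclose(unfused, fused))
```

Because the normalization reduces to a per-channel multiply and add, it can be absorbed into the preceding convolution's weights, which is the kind of pattern a converter can match when the ops appear in this order.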

Standalone MobileNetV3 export script:

```python
import tensorflow as tf

from keras_core.applications.mobilenet_v3 import MobileNetV3Small

keras_core_model = MobileNetV3Small(
    input_shape=(224, 224, 3), minimalistic=True
)
tf_callable = tf.function(
    keras_core_model.call,
    input_signature=[tf.TensorSpec((1, 224, 224, 3), tf.float32)],
    autograph=True,
    jit_compile=True,
)
tf_concrete_function = tf_callable.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_concrete_function], tf_callable
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```

The visualization in Netron (before this PR vs. after this PR): [image]

Benchmark script:

```python
from keras_core import layers
from keras_core import mixed_precision
from keras_core import models
from keras_core import ops

# "float32"
# "mixed_float16"
# "mixed_bfloat16"
dtype_policy = "float32"
mixed_precision.set_dtype_policy(dtype_policy)

x_train = ops.random.uniform(shape=(512, 64, 64, 64))
y_train = ops.random.uniform(shape=(512, 64, 64, 64))

# layers.BatchNormalization
# layers.GroupNormalization
# layers.LayerNormalization
normalization_cls = layers.LayerNormalization
normalization_args = {}
if normalization_cls is layers.GroupNormalization:
    normalization_args = {"groups": -1}

model = models.Sequential(
    [
        layers.InputLayer(shape=(64, 64, 64)),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
    ]
)
model.compile(loss="mse", optimizer="adam")
model.fit(x_train, y_train, batch_size=128, epochs=3)
```

And the improvement:

| backend | layer | before this PR | after this PR |
|---|---|---|---|
| tensorflow | BatchNormalization | 48ms/step | 46ms/step |
| jax | BatchNormalization | 49ms/step | 48ms/step |
| torch | BatchNormalization | 127ms/step | 127ms/step |
| tensorflow | GroupNormalization | 50ms/step | 49ms/step |
| jax | GroupNormalization | 51ms/step | 50ms/step |
| torch | GroupNormalization | 129ms/step | 129ms/step |
| tensorflow | LayerNormalization | 54ms/step | 53ms/step |
| jax | LayerNormalization | 55ms/step | 54ms/step |
| torch | LayerNormalization | 165ms/step | 122ms/step |

codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.25% :tada:

Comparison is base (94b5361) 76.56% compared to head (10e4a03) 76.82%. Report is 4 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #892      +/-   ##
==========================================
+ Coverage   76.56%   76.82%   +0.25%
==========================================
  Files         329      329
  Lines       31429    31426       -3
  Branches     6114     6111       -3
==========================================
+ Hits        24064    24143      +79
+ Misses       5786     5719      -67
+ Partials     1579     1564      -15
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/892/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/892/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `76.72% <100.00%> (+0.25%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/backend/numpy/math.py](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL251bXB5L21hdGgucHk=) | `82.43% <100.00%> (+0.24%)` | :arrow_up: |
| [...s\_core/layers/normalization/batch\_normalization.py](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvbm9ybWFsaXphdGlvbi9iYXRjaF9ub3JtYWxpemF0aW9uLnB5) | `100.00% <100.00%> (ø)` | |
| [...s\_core/layers/normalization/group\_normalization.py](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvbm9ybWFsaXphdGlvbi9ncm91cF9ub3JtYWxpemF0aW9uLnB5) | `97.64% <100.00%> (+8.63%)` | :arrow_up: |
| [...s\_core/layers/normalization/layer\_normalization.py](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvbm9ybWFsaXphdGlvbi9sYXllcl9ub3JtYWxpemF0aW9uLnB5) | `100.00% <100.00%> (+2.59%)` | :arrow_up: |
| [...as\_core/layers/normalization/unit\_normalization.py](https://app.codecov.io/gh/keras-team/keras-core/pull/892?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvbm9ybWFsaXphdGlvbi91bml0X25vcm1hbGl6YXRpb24ucHk=) | `100.00% <100.00%> (+7.69%)` | :arrow_up: |

... and [11 files with indirect coverage changes](https://app.codecov.io/gh/keras-team/keras-core/pull/892/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team)
