keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Add `ops.nn.moments` and speed-up normalization layers #866

Closed · james77777778 closed this 11 months ago

james77777778 commented 12 months ago

EDITED: This PR adds `ops.nn.moments` and speeds up some normalization layers via fast mean and variance computation. It also addresses the overflow and underflow issue when the input tensor is float16.

There are two approaches to compute variance:

| backend | layer | manual implementation (before this PR) | `ops.nn.moments` (fast but unstable, this PR) | `ops.nn.moments` (stable but slower) |
|---|---|---|---|---|
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| jax | BatchNormalization | 63ms | 63ms | 75ms |
| torch | BatchNormalization | 73ms | 72ms | 74ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| jax | GroupNormalization | 96ms | 72ms | 83ms |
| torch | GroupNormalization | 72ms | 74ms | 76ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |
| jax | LayerNormalization | 68ms | 59ms | 60ms |
| torch | LayerNormalization | 88ms | 90ms | 91ms |
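
For reference, a minimal usage sketch of the new op with a channels-last input. The call is assumed here to mirror `tf.nn.moments`, i.e. `moments(x, axes)` returning a `(mean, variance)` pair:

```python
import numpy as np

from keras_core import ops

# A channels-last batch; reduce over batch and spatial axes, as
# BatchNormalization does, to get per-channel statistics.
x = np.random.uniform(size=(8, 32, 32, 3)).astype("float32")

# Assumed signature, mirroring tf.nn.moments: moments(x, axes) -> (mean, variance)
mean, variance = ops.nn.moments(x, axes=[0, 1, 2])

print(mean.shape, variance.shape)  # per-channel shapes, e.g. (3,) and (3,)
```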


codecov[bot] commented 12 months ago

Codecov Report

Patch coverage: 98.94% and project coverage change: +0.06% :tada:

Comparison is base (bb21710) 76.49% compared to head (8f06862) 76.56%.

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##             main     #866    +/-   ##
========================================
+ Coverage   76.49%   76.56%   +0.06%
========================================
  Files         329      329
  Lines       31334    31422     +88
  Branches     6100     6113     +13
========================================
+ Hits        23970    24057     +87
- Misses       5785     5786      +1
  Partials     1579     1579
```

| Flag | Coverage Δ |
|---|---|
| keras_core | 76.46% <98.94%> (+0.06%) ⬆️ |

| Files Changed | Coverage Δ |
|---|---|
| keras_core/ops/nn.py | 89.69% <93.33%> (+0.13%) ⬆️ |
| keras_core/backend/jax/nn.py | 94.00% <100.00%> (+0.59%) ⬆️ |
| keras_core/backend/numpy/nn.py | 93.10% <100.00%> (+0.71%) ⬆️ |
| keras_core/backend/tensorflow/nn.py | 82.10% <100.00%> (+1.20%) ⬆️ |
| keras_core/backend/torch/nn.py | 91.77% <100.00%> (+0.54%) ⬆️ |
| keras_core/layers/normalization/batch_normalization.py | 100.00% <100.00%> (ø) |
| keras_core/layers/normalization/group_normalization.py | 89.01% <100.00%> (-0.12%) ⬇️ |
| keras_core/layers/normalization/layer_normalization.py | 97.40% <100.00%> (+0.03%) ⬆️ |


fchollet commented 12 months ago

Thanks for the PR, this is great! What performance changes did you observe in e.g. BatchNormalization when using moments in different backends compared to the existing manual implementation?

james77777778 commented 12 months ago

> Thanks for the PR, this is great! What performance changes did you observe in e.g. BatchNormalization when using moments in different backends compared to the existing manual implementation?

EDITED: Please see the newest comment and let me know which implementation should be taken.

ORIGINAL: @fchollet

I have reordered the operations in ops.nn.moments to get a significant speed-up. The key is to reduce the tensor first, so that element-wise operations (such as jnp.subtract) run on the much smaller reduced tensors.

    import jax
    import jax.numpy as jnp

    # actually, tf.nn.moments could be faster...

    # original: subtract the (stop-gradient) mean from the full-size tensor,
    # square it, then reduce
    variance = jnp.mean(
        jnp.square(x - jax.lax.stop_gradient(mean)), axis=axes, keepdims=True
    )

    # faster version: compute the reductions first, so the subtraction happens
    # on the small reduced tensors (i.e. E[x^2] - E[x]^2) rather than the full input
    variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
        jax.lax.stop_gradient(mean)
    )

I have also updated the normalization layers. I observed better performance on TF and JAX, while torch showed similar performance.

Benchmark script:

    from keras_core import layers
    from keras_core import models
    from keras_core import ops

    x_train = ops.random.uniform(shape=(1024, 224, 224, 3))
    y_train = ops.random.uniform(shape=(1024, 224, 224, 3))

    # Choose the layer to benchmark: layers.BatchNormalization,
    # layers.GroupNormalization or layers.LayerNormalization
    normalization_cls = layers.BatchNormalization
    normalization_args = {}
    if normalization_cls is layers.GroupNormalization:
        normalization_args = {"groups": 3}

    model = models.Sequential(
        [
            layers.InputLayer(shape=(224, 224, 3)),
            normalization_cls(**normalization_args),
            normalization_cls(**normalization_args),
            normalization_cls(**normalization_args),
        ]
    )
    model.compile(loss="mse", optimizer="adam")
    model.fit(x_train, y_train, batch_size=128, epochs=3)
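
To reproduce the per-backend numbers, the backend can be selected with the KERAS_BACKEND environment variable before keras_core is imported; for example, to run the same script on JAX:

```python
import os

# Must be set before the first keras_core import ("tensorflow", "jax", or "torch").
os.environ["KERAS_BACKEND"] = "jax"

from keras_core import layers, models, ops  # noqa: E402
```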

Results (with a GTX 1080, 8 GB):

| backend | layer | current implementation | using `ops.nn.moments` | notes |
|---|---|---|---|---|
| tensorflow | BatchNormalization | 58ms | 58ms | fair |
| jax | BatchNormalization | 63ms | 63ms | fair |
| torch | BatchNormalization | 73ms | 72ms | fair |
| tensorflow | GroupNormalization | 89ms | 61ms | ⬇️ |
| jax | GroupNormalization | 96ms | 72ms | ⬇️ |
| torch | GroupNormalization | 72ms | 74ms | fair |
| tensorflow | LayerNormalization | 52ms | 47ms | ⬇️ |
| jax | LayerNormalization | 68ms | 59ms | ⬇️ |
| torch | LayerNormalization | 88ms | 90ms | fair |

james77777778 commented 12 months ago

Hi @fchollet, I have a question about the choice of implementation:

In TensorFlow

tf.nn.moments uses a slower but more numerically stable formula to compute the variance:

https://github.com/tensorflow/tensorflow/blob/3d1802023778a164d35c79536990b35b701e8018/tensorflow/python/ops/nn_impl.py#L1264C5-L1268C25

    variance = math_ops.reduce_mean(
        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
        axes,
        keepdims=True,
        name="variance")

In Keras Core and Flax

Keras Core and Flax (the neural-network library for JAX) use a faster but less numerically stable formula to compute the variance:

https://github.com/google/flax/blob/ca3ea06f78834137dfb49dc6c1a0c26fb962003a/flax/linen/normalization.py#L108-L120

    # use_fast_variance=True by default in Flax
    if use_fast_variance:
      mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
      # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
      # to floating point round-off errors.
      var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
    else:
      mu = maybe_distributed_mean(x)
      var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))

Torch

I just checked the CPU version, and torch appears to use the same approach as TensorFlow:

https://github.com/pytorch/pytorch/blob/48e6ffbe308e915b67c5b4f9532f794d6706c903/aten/src/ATen/native/cpu/batch_norm_kernel.cpp#L200-L209
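
For illustration, that two-pass strategy (compute the mean first, then average the squared differences from it) looks roughly like this; a schematic NumPy rendering, not the actual torch kernel:

```python
import numpy as np

def two_pass_moments(x, axes):
    # Pass 1: mean over the reduction axes.
    mean = np.mean(x, axis=axes, keepdims=True)
    # Pass 2: mean of squared differences from that mean. This avoids the
    # E[x^2] - E[x]^2 cancellation, at the cost of a second pass over x.
    variance = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    return mean, variance
```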

Performance Comparison (TensorFlow)

| backend | layer | current implementation | fast variance computation | tf.nn.moments |
|---|---|---|---|---|
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |

Which one should we take?

fchollet commented 11 months ago

Thanks for the analysis -- what does "current implementation" refer to? The implementation in this PR?

james77777778 commented 11 months ago

> Thanks for the analysis -- what does "current implementation" refer to? The implementation in this PR?

Sorry for the confusion. I have updated the table as follows:

| backend | layer | manual implementation (before this PR) | `ops.nn.moments` (fast but unstable variance computation, this PR's implementation) | `ops.nn.moments` (stable but slower variance computation) |
|---|---|---|---|---|
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| jax | BatchNormalization | 63ms | 63ms | 75ms |
| torch | BatchNormalization | 73ms | 72ms | 74ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| jax | GroupNormalization | 96ms | 72ms | 83ms |
| torch | GroupNormalization | 72ms | 74ms | 76ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |
| jax | LayerNormalization | 68ms | 59ms | 60ms |
| torch | LayerNormalization | 88ms | 90ms | 91ms |


The question remains: Should we adopt the fast but unstable variance computation or the stable but slower version?
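
To make the trade-off concrete, here is a small illustration (not from the PR) of the catastrophic cancellation that the fast formula E[x^2] - E[x]^2 can suffer when the mean is large relative to the standard deviation:

```python
import numpy as np

# Worst case for the fast formula: large mean, tiny variance (true variance is 1e-4).
x = np.random.normal(loc=1000.0, scale=0.01, size=100_000).astype("float32")

stable = np.mean(np.square(x - x.mean()))            # two-pass: close to 1e-4
fast = np.mean(np.square(x)) - np.square(x.mean())   # can be far off, or even negative

print(stable, fast)
```

In float32 this only bites when the mean dwarfs the standard deviation, which is presumably why Flax just clamps the result with `jnp.maximum(0.0, ...)` and still defaults to the fast path.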

fchollet commented 11 months ago

My take is that until we see reports of users running into stability issues, the fast implementation should be fine. The fact that Flax defaults to it is evidence that there's little issue. I did a quick search within the google codebase and found only a couple of usages of use_fast_variance=False among thousands of usages of Flax normalization layers (they aren't commented, so it's unclear why they went with False). So it seems that in practice the problem doesn't really surface.

fchollet commented 11 months ago

The code looks good! Do you want to include the normalization layer changes in this PR, or merge this PR first and then create another one?

james77777778 commented 11 months ago

> My take is that until we see reports of users running into stability issues, the fast implementation should be fine. The fact that Flax defaults to it is evidence that there's little issue. I did a quick search within the google codebase and found only a couple of usages of use_fast_variance=False among thousands of usages of Flax normalization layers (they aren't commented, so it's unclear why they went with False). So it seems that in practice the problem doesn't really surface.

Thanks for the valuable insights about the usage in Flax. After a lot of searching, I could only find this comment defending the numerically stable computation of variance (without evidence?): https://github.com/tensorflow/tensorflow/pull/4198#issuecomment-244824858

> The code looks good! Do you want to include the normalization layer changes in this PR, or merge this PR first and then create another one?

I think the changes are already in this PR. Please let me know if I missed anything.

james77777778 commented 11 months ago

Hi @fchollet, this PR should be ready. We now have fast mean and variance computation via ops.nn.moments, and it is applied to BatchNormalization, GroupNormalization and LayerNormalization to achieve some speed-ups.

Additionally, this PR addresses the overflow and underflow issues that occur when the input is float16. (I encountered this before when using GroupNormalization with mixed_float16.) Credits to @fsx950223 (https://github.com/tensorflow/tensorflow/pull/52217).
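
For context on the float16 part: float16 tops out around 65504, so squaring activations with magnitude above roughly 256 overflows to inf, and squaring very small values underflows to 0. The usual mitigation (and, as far as I can tell, the idea behind the fix, though the sketch below is not the PR's actual code) is to compute the statistics in float32 and cast back:

```python
import numpy as np

def float16_safe_moments(x, axes):
    # Hypothetical helper: upcast float16 inputs so the squares can neither
    # overflow nor underflow, then cast the statistics back to the input dtype.
    needs_cast = x.dtype == np.float16
    if needs_cast:
        x = x.astype(np.float32)
    mean = np.mean(x, axis=axes, keepdims=True)
    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)
    variance = np.maximum(variance, 0.0)  # absorb round-off negatives
    if needs_cast:
        mean, variance = mean.astype(np.float16), variance.astype(np.float16)
    return mean, variance
```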