jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.numpy.mean op is slow after conversion to TFLite #20672

Closed noahzhy closed 7 months ago

noahzhy commented 7 months ago

Description

The same model is converted to a TFLite file in two ways (from Keras, and from JAX via jax2tf), but the Mean layer ends up as two different nodes.

import tensorflow as tf
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

tf_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28, 64)),
    tf.keras.layers.Lambda(lambda x: tf.reduce_mean(
        x, axis=(1, 2), keepdims=True)),
])

# Convert the Keras model and save it for comparison.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
tflite_model = converter.convert()
with open("tf_mean.tflite", "wb") as f:
    f.write(tflite_model)

x_input = jnp.zeros((1, 28, 28, 64))

fn = tf.function(
    jax2tf.convert(
        lambda x: jnp.mean(x, axis=(1, 2), keepdims=True),
        enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=list(x_input.shape),
            name='inputs')],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [fn.get_concrete_function(x_input)], fn
)

tflite_model = converter.convert()
with open("jax_mean.tflite", "wb") as f:
    f.write(tflite_model)

[Screenshot: 20240410_061700]

This difference will certainly affect the inference speed of the TFLite model.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.7 (main, Dec 4 2023, 18:10:11) [Clang 15.0.0 (clang-1500.1.0.2.5)]
jax.devices (1 total, 1 local): [METAL(id=0)]

mattjj commented 7 months ago

Thanks for raising this!

Indeed, JAX has no primitive Mean operation. This is common: instead of having lots of distinct kernels, JAX relies on a compiler to fuse primitives into performant kernels. That means that if we have an interpreter downstream of a JAX function instead of a compiler, we might get inefficiencies like this. (It's an open question whether we can have our cake and eat it too, e.g. by staging out a program representation that provides both the Mean function and its implementation in terms of primitives, so that downstream interpreters can use their own direct Mean implementations rather than having to interpret its implementation.)
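To see the decomposition described above, one option is to trace the function with jax.make_jaxpr and list the primitives it is built from. The snippet below is a minimal sketch, not part of the original report: mean_hw and mean_by_hand are hypothetical helper names, and the exact primitive list may differ between JAX versions. On the versions I would expect it to show a sum reduction followed by a division rather than a single mean primitive.

```python
import jax
import jax.numpy as jnp


def mean_hw(x):
    # Same reduction as in the report: average over the two spatial axes.
    return jnp.mean(x, axis=(1, 2), keepdims=True)


def mean_by_hand(x):
    # Hypothetical helper: the decomposition written out explicitly as sum / count.
    count = x.shape[1] * x.shape[2]
    return jnp.sum(x, axis=(1, 2), keepdims=True) / count


x = jnp.ones((1, 28, 28, 64), dtype=jnp.float32)

# Trace the function and collect the names of the primitives in the traced program.
closed_jaxpr = jax.make_jaxpr(mean_hw)(x)
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(closed_jaxpr)
print(primitive_names)
```

A converter that walks this program op by op, without a fusing compiler, sees only these individual primitives, which may explain why the jax2tf path does not come out as a single TFLite Mean node.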

I think this is working-as-intended, but I'm going to tag @gnecula and @superbobry to check. Any thoughts to add?