jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Wrong structure when using jax2tf to convert an nnx.Module to a TFLite file #24497

Open · noahzhy opened 4 days ago

noahzhy commented 4 days ago

Description

Minimal repro: build a small flax.nnx model (a Linear → BatchNorm branch and a Linear → log_softmax branch) and convert it to TFLite via jax2tf:

from flax import nnx
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf
import optax
from flax.training import train_state

class LinearModel(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.linear0 = nnx.Linear(in_features, out_features, rngs=rngs)
        self.bn0 = nnx.BatchNorm(num_features=out_features, rngs=rngs)
        self.linear1 = nnx.Linear(in_features, out_features, rngs=rngs)
        self.act = nnx.log_softmax

    def __call__(self, x):
        x1 = self.linear0(x)
        x1 = self.bn0(x1)
        x2 = self.linear1(x)
        x2 = self.act(x2)
        return x1, x2

key = nnx.Rngs(0)
model = LinearModel(784, 10, key)
model.train()  # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)  # graph structure, trainable params, everything else (e.g. BatchNorm stats)

input_shape = (1, 784)

class TrainState(train_state.TrainState):
    other_variables: nnx.State

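# Bundle apply_fn, trainable params, the non-param state (BatchNorm stats),
# and an optimizer into a single TrainState.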
state = TrainState.create(
    apply_fn=graphdef.apply,
    params=params,
    other_variables=other_variables,
    tx=optax.adam(1e-3),
)

def predict(input_img):
    # graphdef.apply(...) returns a callable module; calling it yields
    # (outputs, updated_state), so [0] keeps only the outputs.
    return state.apply_fn(params, other_variables)(input_img)[0]

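# enable_xla=False makes jax2tf emit only standard TF ops (no XLA custom
# calls) so the graph can be lowered to TFLite builtins; every JAX primitive
# is therefore expanded into one or more TF ops.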
tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=input_shape,
            dtype=tf.float32,
            name='input_image')],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()],
    tf_predict
)

converter.allow_custom_ops = True
converter.experimental_new_converter = True
converter.experimental_new_quantizer = True

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

converter.inference_input_type = tf.float32
converter.inference_output_type = tf.float32
converter.optimizations = [tf.lite.Optimize.DEFAULT]

save_path = 'line.tflite'
with open(save_path, 'wb') as f:
    f.write(converter.convert())

print('\033[92m[done]\033[00m Model converted to tflite.')

But the structure of the converted .tflite file is terrible:

[screenshot: op graph of the converted TFLite model]
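For reference, one way to dump the op-level structure of the converted model without an external visualizer is TFLite's built-in analyzer (a minimal sketch, assuming TF >= 2.9 and the line.tflite written by the script above):

import tensorflow as tf

# Prints every op in the flatbuffer along with its input/output tensors.
tf.lite.experimental.Analyzer.analyze(model_path='line.tflite')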

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.11.10 (main, Sep  7 2024, 01:03:31) [Clang 16.0.0 (clang-1600.0.26.3)]
jax.devices (1 total, 1 local): [METAL(id=0)]

noahzhy commented 4 days ago

BatchNorm is okay, but log_softmax should not be this complicated.
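For what it's worth, some blow-up is expected with enable_xla=False: log_softmax is not a single TF/TFLite primitive, so jax2tf has to spell out the numerically stable decomposition, and each step becomes its own op. A rough sketch of the computation being expanded (plain jax.numpy; not the literal lowering jax2tf produces):

import jax.numpy as jnp

def log_softmax_decomposed(x, axis=-1):
    # log_softmax(x) = x - logsumexp(x), computed stably via a max shift:
    x_max = jnp.max(x, axis=axis, keepdims=True)                    # reduce_max
    shifted = x - x_max                                             # subtract
    sum_exp = jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True)   # exp + reduce_sum
    return shifted - jnp.log(sum_exp)                               # log + subtract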