google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Add support for jax2tf when using nnx.Module #4327

Closed · noahzhy closed this 3 weeks ago

noahzhy commented 3 weeks ago

Please add jax2tf support for nnx.Module. The following script reproduces the problem:

from flax import nnx
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf
import optax
from flax.training import train_state

key = nnx.Rngs(0)
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train()  # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

input_shape = (1, 784)

class TrainState(train_state.TrainState):
    other_variables: nnx.State

state = TrainState.create(
    apply_fn=graphdef.apply,
    params=params,
    other_variables=other_variables,
    tx=optax.adam(1e-3),
)

def predict(input_img):
    # Passing a plain dict of variables here is what triggers the AttributeError below.
    return state.apply_fn(
        {'params': params, 'other_variables': other_variables},
        input_img
    )

tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=input_shape,
            dtype=tf.float32,
            name='input_image')],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)

converter.allow_custom_ops = True
converter.experimental_new_converter = True

The conversion fails with an error like:

AttributeError: 'dict' object has no attribute 'flat_state'
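(For reference, one way to keep the converted function self-contained is to recombine the graph definition and state with nnx.merge, the inverse of nnx.split, inside the predict function. A minimal sketch reusing the objects defined above; predict_merged is a hypothetical name:)

def predict_merged(input_img):
    # Recombine the GraphDef with its State objects into a live module,
    # then call it like a regular nnx.Module.
    merged_model = nnx.merge(graphdef, params, other_variables)
    return merged_model(input_img)

tf_predict_merged = tf.function(
    jax2tf.convert(predict_merged, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=input_shape, dtype=tf.float32, name='input_image')],
    autograph=False)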
noahzhy commented 3 weeks ago

My bad — the problem was on my side: graphdef.apply takes the split nnx.State objects directly rather than a dict, so the following version converts successfully:


from flax import nnx
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf
import optax
from flax.training import train_state

key = nnx.Rngs(0)
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train()  # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

input_shape = (1, 784)

class TrainState(train_state.TrainState):
    other_variables: nnx.State

state = TrainState.create(
    apply_fn=graphdef.apply,
    params=params,
    other_variables=other_variables,
    tx=optax.adam(1e-3),
)

def predict(input_img):
    # graphdef.apply(...) returns a callable whose result is (output, updated state),
    # so take element [0] for the model output.
    return state.apply_fn(params, other_variables)(input_img)[0]

tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=input_shape,
            dtype=tf.float32,
            name='input_image')],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()],
    tf_predict
)

converter.allow_custom_ops = True
converter.experimental_new_converter = True
converter.experimental_new_quantizer = True

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

converter.inference_input_type = tf.float32
converter.inference_output_type = tf.float32
converter.optimizations = [tf.lite.Optimize.DEFAULT]

save_path = 'line.tflite'
with open(save_path, 'wb') as f:
    f.write(converter.convert())

print('\033[92m[done]\033[00m Model converted to tflite.')
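A quick sanity check of the exported file is to load it back with the TFLite interpreter and run one dummy inference (a sketch, assuming the line.tflite written above and the 784-feature input):

import numpy as np
import tensorflow as tf

# Load the converted model and run a single inference on random data.
interpreter = tf.lite.Interpreter(model_path='line.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy_input = np.random.rand(1, 784).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
logits = interpreter.get_tensor(output_details[0]['index'])
print(logits.shape)  # expected (1, 10) for the Linear(784, 10) model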