Closed: noahzhy closed this issue 3 weeks ago.
My bad.
from flax import nnx
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf
import optax
from flax.training import train_state
# Seeded RNG container used to initialise the model's parameters.
key = nnx.Rngs(0)
# A single dense layer mapping 784 input features to 10 outputs. Reuse `key`
# rather than constructing a second, identically seeded `nnx.Rngs(0)`.
model = nnx.Linear(784, 10, rngs=key)
model.train()  # set deterministic=False
# NOTE(review): the model is exported for inference further down; for modules
# containing dropout/batch-norm, `model.eval()` is presumably what is wanted
# here. For a bare nnx.Linear the mode switch likely has no effect — confirm.
# Split the module into its static graph definition, the trainable
# parameters (`nnx.Param`) and all remaining state (`...`).
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
# Input shape used when tracing/exporting the model; the trailing 784 matches
# the Linear layer's input size above.
input_shape = (1, 784)
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries the non-``Param`` NNX state.

    ``nnx.split(model, nnx.Param, ...)`` yields parameters plus "everything
    else"; the parameters go into the inherited ``params`` field and the rest
    is stored in ``other_variables`` so both travel together in one state.
    """

    # Remaining (non-Param) state produced by nnx.split.
    other_variables: nnx.State
# Adam optimiser with a fixed learning rate of 1e-3.
_optimizer = optax.adam(1e-3)

# Bundle the split model pieces and the optimiser into one TrainState.
# `apply_fn` is the graphdef's `apply`, which `predict` later calls with the
# parameters and remaining state to run the module.
state = TrainState.create(
    tx=_optimizer,
    apply_fn=graphdef.apply,
    params=params,
    other_variables=other_variables,
)
def predict(input_img):
    """Run the NNX model held in ``state`` on ``input_img``.

    Parameters and remaining state are read from ``state`` rather than from
    the module-level ``params`` / ``other_variables`` globals. The result
    therefore reflects whatever ``state`` currently holds, not only the
    initial values returned by ``nnx.split``.

    Args:
        input_img: Model input. When invoked through ``tf_predict`` it is a
            float32 tensor of shape ``input_shape``.

    Returns:
        Element ``[0]`` of the result of calling the object returned by
        ``state.apply_fn(...)``. Presumably this is the module output, with
        element ``[1]`` holding updated graph/state (unused here). Confirm
        against the installed flax version.
    """
    # apply_fn(params, other) returns a callable proxy of the module; calling
    # it runs the module and returns an indexable result.
    outputs = state.apply_fn(state.params, state.other_variables)(input_img)
    return outputs[0]
# Wrap the JAX `predict` so TensorFlow can call it. `enable_xla=False` keeps
# the lowering free of XLA-specific ops, presumably so the result can be fed
# to the TFLite converter below.
# NOTE(review): newer JAX releases deprecate `enable_xla=False` in favour of
# native serialization — confirm against the installed jax version.
_predict_as_tf = jax2tf.convert(predict, enable_xla=False)

# Pin the exported signature: a single float32 tensor of `input_shape`,
# named 'input_image'.
_input_image_spec = tf.TensorSpec(
    shape=input_shape,
    dtype=tf.float32,
    name='input_image',
)

tf_predict = tf.function(
    _predict_as_tf,
    input_signature=[_input_image_spec],
    autograph=False,
)
# Trace `tf_predict` for its fixed input signature and hand the concrete
# function to the TFLite converter. `tf_predict` itself is passed as the
# `trackable_obj` the converter keeps a reference to.
_concrete_predict = tf_predict.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [_concrete_predict],
    trackable_obj=tf_predict,
)

# Op coverage: TFLite built-in ops and select TF ops are both allowed, as are
# custom ops.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.allow_custom_ops = True

# Converter / quantizer implementation flags.
converter.experimental_new_converter = True
converter.experimental_new_quantizer = True

# Default optimisation, with model inputs and outputs kept as float32.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.float32
converter.inference_output_type = tf.float32
# Destination of the serialized TFLite flatbuffer.
save_path = 'line.tflite'
# Convert BEFORE opening the output file: opening in 'wb' mode creates or
# truncates `save_path` immediately. If convert() then raised, an empty file
# would be left behind.
tflite_model = converter.convert()
with open(save_path, 'wb') as f:
    f.write(tflite_model)
print('\033[92m[done]\033[00m Model converted to tflite.')
Please add support for `nnx.Module` to `jax2tf`.
The error looks like this: