google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Conversion to TFLite failed #1047

Closed: JuanFMontesinos closed this 2 months ago

JuanFMontesinos commented 2 months ago

Hi, I was trying to convert a model to TFLite for deployment. I'm doing this as a two-stage process: first converting the JAX model to TF's SavedModel format, and then converting that to TFLite, as follows:

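import json
from pathlib import Path
from pprint import pprint

import tensorflow as tf
from orbax.export import ExportManager, JaxModule, ServingConfig

# LRU is the poster's own model class; its import is not shown in the original post.


def export_model(model_checkpoint: Path, output_dir: Path, prep_fn=None, post_fn=None):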
    """
    Export a jax model to a tensorflow SavedModel format.

    Note: https://www.tensorflow.org/guide/saved_model
    """
    model = LRU.from_serialization(model_checkpoint.absolute(), batch_size=1)
    print(f"[*] Model loaded from checkpoint {model_checkpoint}")
    print("Model configuration:")
    pprint(model.model_cfg)

    server_config = [
        ServingConfig(
            "serving_default",
            input_signature=[
                tf.TensorSpec(
                    shape=[
                        None,
                        model.model_cfg["inference"]["seq_length"],
                        model.model_cfg["inference"]["input_dims"],
                    ],
                    dtype=tf.float32,
                ),
            ],
            # Pre- and post-processing hooks (these two were swapped in the original post).
            tf_preprocessor=prep_fn,
            tf_postprocessor=post_fn,
        )
    ]

    jax_module = JaxModule(model.state, model.model.apply, input_polymorphic_shape="b, ...")
    export_mgr = ExportManager(jax_module, server_config)
    export_mgr.save(output_dir.as_posix())
    with open(output_dir / "aiwizard-config.json", "w") as f:
        json.dump(model.model_cfg, f)
    return export_mgr, model

def export_tf2tflite(saved_model_dir: Path, tflite_dir: Path):
    """
    Convert a tensorflow SavedModel to a tflite model.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir.as_posix())
    converter.experimental_enable_resource_variables = True
    tflite_model = converter.convert()
    with open(tflite_dir / "model.tflite", "wb") as f:
        f.write(tflite_model)

When converting the model from tf to tflite, I get the following error:

tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
/home/jmt/Projects/panama/panama-external/src/jax/lru/lru/core/model.py:91:23: error: failed to legalize operation 'tfl.transpose' that was explicitly marked illegal
        Bu_elements = (B_norm @ inputs.T).T
                      ^
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from

The error points to this line in the following forward function:

        """Forward pass of a LRU: h_t+1 = lambda * h_t + B x_t+1, y_t = Re[C h_t + D x_t]"""
        diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log))
        B_norm = (self.B_re + 1j * self.B_im) * jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1)
        C = self.C_re + 1j * self.C_im

        Lambda_elements = jnp.repeat(diag_lambda[None, ...], inputs.shape[0], axis=0)
        Bu_elements = (B_norm @ inputs.T).T
        # Compute hidden states
        _, hidden_states = parallel_scan(binary_operator_diag, (Lambda_elements, Bu_elements))
        # Use them to compute the output of the module
        outputs = jax.vmap(lambda x, u: (C @ x).real + self.D * u)(hidden_states, inputs)

        return outputs

Do you have any clue why a simple transposition would fail?

Thanks, Juan
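
One possible explanation, not confirmed anywhere in this thread: `B_norm` is complex, so `(B_norm @ inputs.T).T` transposes a complex-valued tensor, and the TFLite `transpose` kernel may not accept complex dtypes. The sketch below assumes that this is the cause and that `inputs` is real-valued. It uses NumPy as a stand-in for `jax.numpy`, with hypothetical shapes and random values, and it omits the `gamma_log` scaling. It shows that the same `Bu_elements` can be computed from the real and imaginary parts separately, so that only real tensors are transposed.

```python
import numpy as np

# Hypothetical sizes and random values, for illustration only.
rng = np.random.default_rng(0)
seq_len, in_dim, state_dim = 5, 3, 4

inputs = rng.standard_normal((seq_len, in_dim)).astype(np.float32)
B_re = rng.standard_normal((state_dim, in_dim)).astype(np.float32)
B_im = rng.standard_normal((state_dim, in_dim)).astype(np.float32)


def bu_original(B_re, B_im, inputs):
    """Form used in the post: the final .T acts on a complex result."""
    B_norm = B_re + 1j * B_im
    return (B_norm @ inputs.T).T


def bu_split(B_re, B_im, inputs):
    """Equivalent form: (B @ X.T).T == X @ B.T, applied to each real part."""
    bu_re = inputs @ B_re.T  # real-valued transpose and matmul
    bu_im = inputs @ B_im.T  # real-valued transpose and matmul
    return bu_re + 1j * bu_im  # recombine into a complex array


reference = bu_original(B_re, B_im, inputs)
rewritten = bu_split(B_re, B_im, inputs)

# Both forms have shape (seq_len, state_dim) and agree numerically.
assert reference.shape == rewritten.shape == (seq_len, state_dim)
assert np.allclose(reference, rewritten, atol=1e-5)
```

The rest of the forward pass (the scan and `C @ x`) still operates on complex values, so this rewrite alone may not be enough for the conversion to succeed.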

vroulet commented 2 months ago

Hello @JuanFMontesinos, I don't see how this is related to optax. Maybe you meant to post an issue in JAX or FLAX channels?

JuanFMontesinos commented 2 months ago

Hi @vroulet, indeed, my bad. The exporter is Orbax's, and I think my brain just switched to Optax. Thanks for noticing!