keras-team / keras

How to Convert Keras Model's Inference Method (with JAX Backend) to Flax Training State for Using Flax to Predict #20255

Closed KaiyueDuan closed 1 month ago

KaiyueDuan commented 1 month ago

I am using TensorFlow as the backend to train a DNN model. After training, I successfully converted the inference method of the Keras model to a JAX function and created a Flax training_state to perform inference using Flax. This workflow is working well. Here is my notebook.

However, when I switch to using JAX as the backend, I am unsure how to convert the inference method of the Keras model into a JAX function. Furthermore, I am also unclear about the steps needed to create a Flax training_state afterwards.

Could anyone provide guidance on how to achieve this? Any help would be greatly appreciated!

fchollet commented 1 month ago

You can just use model.stateless_call(trainable_variables, non_trainable_variables, *args) (args are the model/layer's call arguments). This is a pure JAX function.

fchollet commented 1 month ago

Check out this example. https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py

KaiyueDuan commented 1 month ago

Thank you, François! Your suggestion worked perfectly for my program. I successfully computed the Jacobian matrix using model.stateless_call(). Here is the code I used:

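def func_to_diff(x):
    # Add a batch axis, since the model expects batched input.
    x = x[None, :]
    # stateless_call returns (outputs, non_trainable_variables); keep the outputs.
    return model.stateless_call(trainable_variables, non_trainable_variables, x)[0]

def jac_fwd_lambda(single_input):
    # Forward-mode Jacobian of the outputs with respect to one input sample.
    return jax.jacfwd(func_to_diff)(single_input)

# Map over the leading (batch) axis to get one Jacobian per sample.
jax.vmap(jac_fwd_lambda, in_axes=0)(input_data)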
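
One way to sanity-check the jacfwd plus vmap pattern is to run it on a function whose Jacobian is known in closed form. The sketch below swaps the Keras model for a hypothetical stand-in, y = tanh(W x), with an arbitrary weight matrix W. None of these names come from the thread. The stand-in takes batched input, so func_to_diff adds and strips a batch axis in the same spirit as the code above.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the network: y = tanh(W x), with arbitrary weights.
W = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0]])


def batched_model(x_batch):
    # Expects input of shape (batch, 3), like a model called on a batch.
    return jnp.tanh(x_batch @ W.T)


def func_to_diff(x):
    # Add a batch axis for the model, then strip it from the output.
    return batched_model(x[None, :])[0]


input_data = jnp.arange(15.0).reshape(5, 3) / 10.0

# One Jacobian per sample: shape (batch, n_outputs, n_inputs).
jacobians = jax.vmap(jax.jacfwd(func_to_diff))(input_data)

# Analytic Jacobian of tanh(W x) is diag(1 - tanh(W x)^2) @ W.
pre_activation = input_data @ W.T
expected = (1.0 - jnp.tanh(pre_activation) ** 2)[:, :, None] * W[None, :, :]
```

If jacobians and expected agree to within floating-point tolerance, the batching and differentiation axes are wired up as intended.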

KaiyueDuan commented 1 month ago

Hi, I encountered a new issue. I noticed that in the sample code https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py, the model is a global variable.

However, in my code, I need to pass my model as a function parameter; here is my code:

from functools import partial
key = jax.random.PRNGKey(42)
tf2jax.update_config('strict_shape_check', False)

@jax.jit
def _jax_predict_tf_single(model_state, single_point):
    return model_state.apply_fn( model_state.params, single_point)[0]

@partial(jax.jit, static_argnums=0)
def predict_jax_single(my_model,single_point):
    return  my_model.stateless_call(my_model.trainable_variables,my_model.non_trainable_variables,single_point[None, :])[0].squeeze(axis=0)

def f_jacfwd(predict_single,my_model,input_data):
  # mode 0: flax; mode 1: tf; mode 2: jax; otherwise: error

    def jac_fwd_lambda(single_input):
        return jax.jacfwd(predict_single)(my_model,single_input)

    return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)

input_data = np.ones((5,3))
ret = f_jacfwd(predict_jax_single, model, input_data)

The code throws the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-0d48a9ca553d> in <cell line: 24>()
     22 
     23 input_data = np.ones((5,3))
---> 24 ret = f_jacfwd(predict_jax_single, model, input_data)

<ipython-input-6-0d48a9ca553d> in f_jacfwd(predict_single, my_model, input_data)
     19         return jax.jacfwd(predict_single)(my_model,single_input)
     20 
---> 21     return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
     22 
     23 input_data = np.ones((5,3))

    [... skipping hidden 3 frame]

<ipython-input-6-0d48a9ca553d> in jac_fwd_lambda(single_input)
     17 
     18     def jac_fwd_lambda(single_input):
---> 19         return jax.jacfwd(predict_single)(my_model,single_input)
     20 
     21     return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)

    [... skipping hidden 4 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in check_arg(arg)
    279 def check_arg(arg: Any):
    280   if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
--> 281     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
    282                     "JAX type.")
    283 

TypeError: Argument '<Functional name=functional, built=True>' of type <class 'keras.src.models.functional.Functional'> is not a valid JAX type.

I also found that the type of model.stateless_call is

<class 'method'>
{'__wrapped__': <function Layer.stateless_call at 0x795fc6fdfa30>}

which does not seem to be a JAX function.

So I am wondering how to pass the model as a parameter. Here is my notebook https://colab.research.google.com/drive/1nV8oIn4TzgmtcAk1xFaaFg4EnxLN9n4c?usp=sharing

Many thanks!
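
A note on the likely cause, judging from the traceback: jax.jacfwd differentiates with respect to positional argument 0 by default (argnums=0), so jax.jacfwd(predict_single)(my_model, single_input) asks JAX to differentiate with respect to my_model, which is not a valid JAX type. Passing argnums=1 should target single_input instead. The sketch below reproduces this with a hypothetical ToyModel class standing in for the Keras model. It has not been run against the original notebook.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a Keras model: a plain Python object."""

    def __init__(self):
        self.w = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0]])

    def __call__(self, x):
        return jnp.tanh(self.w @ x)


@partial(jax.jit, static_argnums=0)
def predict_single(my_model, single_point):
    # my_model is a static (hashable) argument; only single_point is traced.
    return my_model(single_point)


my_model = ToyModel()
input_data = jnp.ones((5, 3))

# Default argnums=0: jacfwd tries to differentiate with respect to my_model.
try:
    jax.jacfwd(predict_single)(my_model, input_data[0])
    raised = False
except TypeError:
    raised = True


def jac_single(single_input):
    # argnums=1: differentiate with respect to single_input only.
    return jax.jacfwd(predict_single, argnums=1)(my_model, single_input)


jacobians = jax.vmap(jac_single)(input_data)
```

With argnums=1 the model is passed through untouched, so it can remain a function parameter rather than a global.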

KaiyueDuan commented 1 month ago

Update: I solved this issue by defining an inner function, model_call, which uses the local variable my_model:

@partial(jax.jit, static_argnums=(0,1))
def f_jacfwd(predict_single,my_model,input_data):
    def jac_fwd_lambda(single_input):
        if "jax_single" in predict_single.__name__:
            def model_call(input_val):
                result = my_model.stateless_call(my_model.trainable_variables, my_model.non_trainable_variables, input_val[None, :])[0]
                return result.squeeze(axis=0)
            return jax.jacfwd(model_call)(single_input)
        return jax.jacfwd(predict_single)(my_model,single_input)

    return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
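
One caveat with marking the model as a static argument: anything read from it inside the jitted function is captured as a constant when the function is traced. If the weights change later, a cached compilation may keep using the old values. The sketch below illustrates this general JAX behavior with a hypothetical ToyModel stand-in, not a real Keras model. For Keras, the analogous precaution would presumably be to pass the variable values in as explicit arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a model whose weights can change."""

    def __init__(self):
        self.w = jnp.array([1.0, 2.0, 3.0])


@partial(jax.jit, static_argnums=0)
def predict_closed_over(my_model, x):
    # my_model.w is read during tracing and baked in as a constant.
    return jnp.dot(my_model.w, x)


@jax.jit
def predict_explicit(w, x):
    # w is a traced argument, so new values are always picked up.
    return jnp.dot(w, x)


my_model = ToyModel()
x = jnp.ones(3)

before = predict_closed_over(my_model, x)

# Simulate a weight update, for example after further training.
my_model.w = my_model.w * 2.0

stale = predict_closed_over(my_model, x)  # same object, so the cache is reused
fresh = predict_explicit(my_model.w, x)   # sees the updated weights
```

Here stale equals before even though the weights doubled, while fresh reflects the update.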