Closed KaiyueDuan closed 1 month ago
You can just use model.stateless_call(trainable_variables, non_trainable_variables, *args) (args are the model/layer's call arguments). This is a pure JAX function.
Check out this example: https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py
Thank you, François! Your suggestion worked perfectly for my program. I successfully computed the Jacobian matrix using model.stateless_call(). Here is the code I used:
def func_to_diff(x):
    x = x[None, :]
    return model.stateless_call(trainable_variables, non_trainable_variables, x)[0]

def jac_fwd_lambda(single_input):
    return jax.jacfwd(func_to_diff)(single_input)

jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
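For readers without the original model at hand, the same pattern can be tried in plain JAX. This is a minimal sketch, not the poster's code: the `stateless_call` function below is a hypothetical stand-in that only mimics the Keras convention of returning `(outputs, non_trainable_variables)`, which is why `[0]` selects the outputs. The weights and shapes are made up for illustration, and dropping the batch axis before differentiating is an addition here so that each per-sample Jacobian has shape `(outputs, inputs)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.stateless_call: a single dense layer that
# returns (outputs, non_trainable_variables), mirroring the tuple convention.
def stateless_call(trainable_variables, non_trainable_variables, x):
    W, b = trainable_variables
    return jnp.tanh(x @ W + b), non_trainable_variables

# Illustrative variables: 3 inputs, 2 outputs.
trainable_variables = (jnp.ones((3, 2)) * 0.1, jnp.zeros((2,)))
non_trainable_variables = ()

def func_to_diff(x):
    x = x[None, :]  # add a batch axis: (3,) -> (1, 3)
    y = stateless_call(trainable_variables, non_trainable_variables, x)[0]
    return y[0]     # drop the batch axis again: (1, 2) -> (2,)

input_data = jnp.ones((5, 3))

# One Jacobian per sample: vmap maps jacfwd over the leading axis.
jacobians = jax.vmap(jax.jacfwd(func_to_diff))(input_data)
print(jacobians.shape)  # (5, 2, 3)
```

Because the variables are captured from the enclosing scope, `func_to_diff` takes only the input, so `jax.jacfwd` differentiates with respect to the input and nothing else.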
Hi,
I encountered a new issue. I noticed that in the sample code https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py, the model is a global variable.
However, in my code, I need to pass my model as a function parameter. Here is my code:
from functools import partial

key = jax.random.PRNGKey(42)
tf2jax.update_config('strict_shape_check', False)

@jax.jit
def _jax_predict_tf_single(model_state, single_point):
    return model_state.apply_fn(model_state.params, single_point)[0]

@partial(jax.jit, static_argnums=0)
def predict_jax_single(my_model, single_point):
    return my_model.stateless_call(my_model.trainable_variables, my_model.non_trainable_variables, single_point[None, :])[0].squeeze(axis=0)

def f_jacfwd(predict_single, my_model, input_data):
    # mode 0: flax; mode 1: tf; mode 2: jax; otherwise: error
    def jac_fwd_lambda(single_input):
        return jax.jacfwd(predict_single)(my_model, single_input)
    return jax.vmap(predict_single, in_axes=(None, 0))(my_model, input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)

input_data = np.ones((5, 3))
ret = f_jacfwd(predict_jax_single, model, input_data)
The code throws an error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-6-0d48a9ca553d>](https://localhost:8080/#) in <cell line: 24>()
22
23 input_data = np.ones((5,3))
---> 24 ret = f_jacfwd(predict_jax_single, model, input_data)
2 frames
[<ipython-input-6-0d48a9ca553d>](https://localhost:8080/#) in f_jacfwd(predict_single, my_model, input_data)
19 return jax.jacfwd(predict_single)(my_model,single_input)
20
---> 21 return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
22
23 input_data = np.ones((5,3))
[... skipping hidden 3 frame]
[<ipython-input-6-0d48a9ca553d>](https://localhost:8080/#) in jac_fwd_lambda(single_input)
17
18 def jac_fwd_lambda(single_input):
---> 19 return jax.jacfwd(predict_single)(my_model,single_input)
20
21 return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
[... skipping hidden 4 frame]
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in check_arg(arg)
279 def check_arg(arg: Any):
280 if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
--> 281 raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
282 "JAX type.")
283
TypeError: Argument '<Functional name=functional, built=True>' of type <class 'keras.src.models.functional.Functional'> is not a valid JAX type.
I also found that the type of model.stateless_call is
<class 'method'>
{'__wrapped__': <function Layer.stateless_call at 0x795fc6fdfa30>}
which does not seem to be a JAX function.
So I am wondering how to pass the model as a parameter. Here is my notebook https://colab.research.google.com/drive/1nV8oIn4TzgmtcAk1xFaaFg4EnxLN9n4c?usp=sharing
Many thanks!
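A note on the traceback: it is consistent with `jax.jacfwd` differentiating with respect to its first positional argument by default (`argnums=0`), which in `jax.jacfwd(predict_single)(my_model, single_input)` is the model object rather than an array. One option that would likely avoid the error is to pass `argnums=1`, so that the input is the differentiated argument and the model is held fixed. The sketch below illustrates this in plain JAX only; `TinyModel` and its `predict` method are hypothetical stand-ins, not Keras APIs, and this has not been checked against the notebook above.

```python
import jax
import jax.numpy as jnp

class TinyModel:
    """Hypothetical stand-in for a Keras model; not a JAX array or pytree."""

    def __init__(self):
        self.w = jnp.array([[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]])

    def predict(self, x):
        return jnp.sin(self.w @ x)

def predict_single(my_model, single_point):
    # The model is an ordinary Python object passed as the first argument.
    return my_model.predict(single_point)

model = TinyModel()
input_data = jnp.ones((5, 3))

def jac_fwd_lambda(single_input):
    # argnums=1 differentiates w.r.t. single_input; the model is held fixed.
    return jax.jacfwd(predict_single, argnums=1)(model, single_input)

jacobians = jax.vmap(jac_fwd_lambda)(input_data)
print(jacobians.shape)  # (5, 2, 3)
```

With the default `argnums=0`, the same call would ask JAX to treat the model itself as the quantity to differentiate, which is what the "not a valid JAX type" message points at.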
Update: I solved this issue by defining an inner function model_call, which uses the local variable my_model:
@partial(jax.jit, static_argnums=(0, 1))
def f_jacfwd(predict_single, my_model, input_data):
    def jac_fwd_lambda(single_input):
        if "jax_single" in predict_single.__name__:
            def model_call(input_val):
                result = my_model.stateless_call(my_model.trainable_variables, my_model.non_trainable_variables, input_val[None, :])[0]
                return result.squeeze(axis=0)
            return jax.jacfwd(model_call)(single_input)
        return jax.jacfwd(predict_single)(my_model, single_input)
    return jax.vmap(predict_single, in_axes=(None, 0))(my_model, input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
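The closure idea in this update can also be written as a small factory, which avoids branching on the function name. The following is a minimal sketch under stated assumptions: `TinyModel`, `make_predict_and_jac`, and `predict_and_jac` are hypothetical names introduced for illustration, and the stand-in model is plain JAX rather than Keras.

```python
import jax
import jax.numpy as jnp

class TinyModel:
    """Hypothetical stand-in for a model object that is not a JAX type."""

    def __init__(self, scale):
        self.scale = scale

    def predict(self, x):
        return self.scale * x ** 2

def make_predict_and_jac(my_model):
    # predict_single closes over my_model, so JAX only ever sees the array input.
    def predict_single(single_point):
        return my_model.predict(single_point)

    @jax.jit
    def predict_and_jac(input_data):
        preds = jax.vmap(predict_single)(input_data)
        jacs = jax.vmap(jax.jacfwd(predict_single))(input_data)
        return preds, jacs

    return predict_and_jac

predict_and_jac = make_predict_and_jac(TinyModel(scale=2.0))
preds, jacs = predict_and_jac(jnp.ones((5, 3)))
print(preds.shape, jacs.shape)  # (5, 3) (5, 3, 3)
```

One caveat with any closure under `jax.jit`: array values captured from the enclosing scope are baked in when the function is traced, so if the model's weights change later, the factory would presumably need to be called again to pick them up.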
I am using TensorFlow as the backend to train a DNN model. After training, I successfully converted the inference method of the Keras model to a JAX function and created a Flax training_state to perform inference using Flax. This workflow is working well. Here is my notebook.
However, when I switch to using JAX as the backend, I am unsure how to convert the inference method of the Keras model into a JAX function. Furthermore, I am also unclear about the steps needed to create a Flax training_state afterwards.
Could anyone provide guidance on how to achieve this? Any help would be greatly appreciated!