dionhaefner opened this issue 1 year ago
I'm having a similar issue when following the README here: https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples/serving
I get this error when running the python command:
TypeError: call_module() got an unexpected keyword argument 'function_list'
Full output:
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 210, in <module>
    app.run(lambda _: train_and_save())
  File "/Users/ako/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/ako/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 210, in <lambda>
    app.run(lambda _: train_and_save())
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 125, in train_and_save
    saved_model_lib.convert_and_save_model(
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_lib.py", line 115, in convert_and_save_model
    tf_graph.get_concrete_function(input_signatures[0])
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1189, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1169, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 694, in _initialize
    self._variable_creation_fn # pylint: disable=protected-access
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 176, in _get_concrete_function_internal_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 171, in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 398, in _maybe_define_function
    concrete_function = self._create_concrete_function(
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 305, in _create_concrete_function
    func_graph_module.func_graph_from_py_func(
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1055, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 597, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_lib.py", line 108, in <lambda>
    tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 417, in converted_fun_tf
    outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 343, in __call__
    return self._d(self._f, a, k)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 299, in decorated
    return _graph_mode_decorator(wrapped, args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 425, in _graph_mode_decorator
    result, grad_fn = f(*args)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 409, in converted_fun_flat_with_custom_gradient_tf
    outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 518, in run_fun_tf
    results = _run_exported_as_tf(args_flat_tf, self.exported)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 877, in _run_exported_as_tf
    res = tfxla.call_module(args_flat_tf, **call_module_attrs)
TypeError: call_module() got an unexpected keyword argument 'function_list'
I'm using the same jax version, on CPU, on a MacBook M2.
I just downgraded jax to v0.4.12, v0.4.13 and v0.4.14, and it passes on all of them (I started at a way too old version, oops), but the issue starts from jax[cpu]==v0.4.15.dev20230919.
In general, jax2tf tests require the nightly TensorFlow release (pip install tf-nightly). They often get out of sync with the stable TensorFlow release.
Thanks, I tried that, but got the same result.
$ pip freeze | grep "tf-nightly"
tf-nightly==2.15.0.dev20230919
tf-nightly-macos==2.15.0.dev20230919
call_module accepts a function_list argument as of four months ago: https://github.com/tensorflow/tensorflow/blame/a90eb068e805fbe39ccfd7f4bfc2e33dd8a592a0/tensorflow/compiler/tf2xla/python/xla.py#L637
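One way to confirm which variant of call_module the interpreter actually imports is to inspect its signature. This is a minimal diagnostic sketch; it assumes the same import path that jax2tf uses (from tensorflow.compiler.tf2xla.python import xla as tfxla), and the helper names accepts_kwarg and check_call_module are hypothetical, written only for this illustration.

```python
import inspect
from typing import Callable


def accepts_kwarg(fn: Callable, name: str) -> bool:
    """Return True if `fn` can be called with `name` as a keyword argument."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some C extensions): report False.
        return False
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def check_call_module() -> None:
    """Report whether the imported TensorFlow's call_module takes function_list."""
    try:
        import tensorflow as tf
        # Assumption: same module path that jax2tf imports tfxla from.
        from tensorflow.compiler.tf2xla.python import xla as tfxla
    except ImportError as err:
        print(f"Could not import TensorFlow's tf2xla module: {err}")
        return
    print("tensorflow", tf.__version__, "imported from", tf.__file__)
    call_module = getattr(tfxla, "call_module", None)
    if call_module is None:
        print("tfxla.call_module does not exist in this TensorFlow build")
    else:
        print("call_module accepts function_list:",
              accepts_kwarg(call_module, "function_list"))


if __name__ == "__main__":
    check_call_module()
```

If this prints False, the interpreter is loading a TensorFlow older than the change linked above, regardless of what pip freeze reports.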
Can you double-check that the Python executable you're running is picking up the nightly tensorflow listed by pip? For example:
$ python -c "import tensorflow; print(tensorflow.__version__)"
2.15.0-dev20230919
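Extending that one-liner to also print the interpreter path and the location of the imported package makes a mismatch between pip and python easier to spot. A minimal sketch: the helper looks_like_nightly is hypothetical, and treating any version string containing "dev" as a nightly build is an assumption based on the strings in this thread (2.15.0.dev20230919 from pip, 2.15.0-dev20230919 from tensorflow.__version__).

```python
import sys


def looks_like_nightly(version: str) -> bool:
    """Heuristic (assumption): nightly TF versions carry a 'dev' suffix."""
    return "dev" in version


def report() -> None:
    # The interpreter that is running; it should be the one pip installed into.
    print("python:", sys.executable)
    try:
        import tensorflow as tf
    except ImportError:
        print("tensorflow is not importable from this interpreter")
        return
    print("tensorflow:", tf.__version__, "from", tf.__file__)
    if not looks_like_nightly(tf.__version__):
        print("warning: not a nightly build; jax2tf may need tf-nightly")


if __name__ == "__main__":
    report()
```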
Thanks, that was it. The problem is that tensorflow_serving_api depends on tensorflow, so the stable tensorflow package gets installed along with tf-nightly and is found first.
A workaround is to do this instead:
$ pip install matplotlib flax jaxlib tensorflow_datasets tensorflow_serving_api -e jax
$ pip install tf-nightly
Maybe the jax2tf serving guide should recommend that instead?
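To check whether a stable tensorflow distribution ended up installed next to tf-nightly, as described above, one can list the installed distributions with importlib.metadata (Python 3.8+). This is a sketch under assumptions: the candidate names come from this thread plus tensorflow-macos, each is assumed to install or pull in the same top-level tensorflow package, and the helpers installed_versions and stable_and_nightly are hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable

# Distribution names from this thread, plus tensorflow-macos (assumption).
CANDIDATES = ("tensorflow", "tensorflow-macos", "tf-nightly", "tf-nightly-macos")


def installed_versions(names: Iterable[str] = CANDIDATES) -> Dict[str, str]:
    """Map each installed distribution name to its version; skip missing ones."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass
    return found


def stable_and_nightly(found: Dict[str, str]) -> bool:
    """True if a stable and a nightly TensorFlow distribution are both present."""
    has_stable = any(n in ("tensorflow", "tensorflow-macos") for n in found)
    has_nightly = any(n.startswith("tf-nightly") for n in found)
    return has_stable and has_nightly


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name}=={version}")
    if stable_and_nightly(found):
        print("warning: stable tensorflow and tf-nightly are both installed")
```

A warning here matches the situation in this thread, where tensorflow_serving_api pulled in the stable tensorflow alongside tf-nightly.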
What jax/jaxlib version are you using?
jax v0.4.16 jaxlib v0.4.16
Which accelerator(s) are you using?
CPU
Additional system info
OSX