jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[jax2tf] Tests are failing #17660

Open dionhaefner opened 1 year ago

dionhaefner commented 1 year ago

Description

$ cd /tmp
$ git clone https://github.com/google/jax.git -b jax-v0.4.16
$ pip install matplotlib flax jaxlib tensorflow_datasets tensorflow_serving_api tensorflow -e jax
$ python jax/jax/experimental/jax2tf/examples/saved_model_main_test.py

Output:

Running tests under Python 3.10.12: /Users/dion/.virtualenvs/tempenv-6bbe5525767/bin/python
[ RUN      ] SavedModelMainTest.test_train_and_save_features_mnist_flax
I0919 14:49:21.999132 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695127762.000160       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0919 14:49:22.000297 8342069376 xla_bridge.py:513] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0919 14:49:22.000332 8342069376 xla_bridge.py:513] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0919 14:49:22.002579 8342069376 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
[  FAILED  ] SavedModelMainTest.test_train_and_save_features_mnist_flax
[ RUN      ] SavedModelMainTest.test_train_and_save_features_mnist_pure_jax
I0919 14:49:22.003413 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
[  FAILED  ] SavedModelMainTest.test_train_and_save_features_mnist_pure_jax
[ RUN      ] SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=-1
I0919 14:49:22.003628 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
[  FAILED  ] SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=-1
[ RUN      ] SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=1
I0919 14:49:22.003809 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
[  FAILED  ] SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=1
[ RUN      ] SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=-1
I0919 14:49:22.003971 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
[  FAILED  ] SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=-1
[ RUN      ] SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=1
I0919 14:49:22.004137 8342069376 tf_test_util.py:172] Running jax2tf converted code on LogicalDevice(name='/device:CPU:0', device_type='CPU').
[  FAILED  ] SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=1
======================================================================
ERROR: test_train_and_save_features_mnist_flax (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_features_mnist_flax
test_train_and_save_features_mnist_flax(model='mnist_flax')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

======================================================================
ERROR: test_train_and_save_features_mnist_pure_jax (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_features_mnist_pure_jax
test_train_and_save_features_mnist_pure_jax(model='mnist_pure_jax')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

======================================================================
ERROR: test_train_and_save_full_mnist_flax_batch=-1 (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=-1
test_train_and_save_full_mnist_flax_batch=-1(model='mnist_flax', serving_batch_size=-1)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

======================================================================
ERROR: test_train_and_save_full_mnist_flax_batch=1 (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_full_mnist_flax_batch=1
test_train_and_save_full_mnist_flax_batch=1(model='mnist_flax', serving_batch_size=1)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

======================================================================
ERROR: test_train_and_save_full_mnist_pure_jax_batch=-1 (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=-1
test_train_and_save_full_mnist_pure_jax_batch=-1(model='mnist_pure_jax', serving_batch_size=-1)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

======================================================================
ERROR: test_train_and_save_full_mnist_pure_jax_batch=1 (__main__.SavedModelMainTest)
SavedModelMainTest.test_train_and_save_full_mnist_pure_jax_batch=1
test_train_and_save_full_mnist_pure_jax_batch=1(model='mnist_pure_jax', serving_batch_size=1)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/examples/saved_model_main_test.py", line 33, in setUp
    super().setUp()
  File "/Users/dion/codes/jax/jax/experimental/jax2tf/tests/tf_test_util.py", line 186, in setUp
    tfxla.call_module_maximum_supported_version())
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'call_module_maximum_supported_version'

----------------------------------------------------------------------
Ran 6 tests in 0.006s

FAILED (errors=6)
I0000 00:00:1695127762.198402       1 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

What jax/jaxlib version are you using?

jax v0.4.16, jaxlib v0.4.16

Which accelerator(s) are you using?

CPU

Additional system info

OSX

NVIDIA GPU info

No response

angela-ko commented 1 year ago

Having a similar issue when trying to follow the README here: https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples/serving

Getting this error when running the python command: TypeError: call_module() got an unexpected keyword argument 'function_list'

Full output:

  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 210, in <module>
    app.run(lambda _: train_and_save())
  File "/Users/ako/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/ako/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 210, in <lambda>
    app.run(lambda _: train_and_save())
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_main.py", line 125, in train_and_save
    saved_model_lib.convert_and_save_model(
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_lib.py", line 115, in convert_and_save_model
    tf_graph.get_concrete_function(input_signatures[0])
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1189, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1169, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 694, in _initialize
    self._variable_creation_fn    # pylint: disable=protected-access
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 176, in _get_concrete_function_internal_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 171, in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 398, in _maybe_define_function
    concrete_function = self._create_concrete_function(
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 305, in _create_concrete_function
    func_graph_module.func_graph_from_py_func(
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1055, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 597, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/examples/saved_model_lib.py", line 108, in <lambda>
    tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 417, in converted_fun_tf
    outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 343, in __call__
    return self._d(self._f, a, k)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 299, in decorated
    return _graph_mode_decorator(wrapped, args, kwargs)
  File "/Users/ako/venv/lib/python3.10/site-packages/tensorflow/python/ops/custom_gradient.py", line 425, in _graph_mode_decorator
    result, grad_fn = f(*args)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 409, in converted_fun_flat_with_custom_gradient_tf
    outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 518, in run_fun_tf
    results = _run_exported_as_tf(args_flat_tf, self.exported)
  File "/Users/ako/work/jax/jax/experimental/jax2tf/jax2tf.py", line 877, in _run_exported_as_tf
    res = tfxla.call_module(args_flat_tf, **call_module_attrs)
TypeError: call_module() got an unexpected keyword argument 'function_list'

Using the same jax version, CPU, MacBook M2

angela-ko commented 1 year ago

Just downgraded jax to v0.4.12/13/14 and the tests pass on all of them (I started at a way too old version, oops); the issue starts with jax[cpu]==v0.4.15.dev20230919

jakevdp commented 1 year ago

In general, the jax2tf tests require the nightly tensorflow release (pip install tf-nightly); they often get out of sync with the stable tensorflow release.
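
To check whether the installed TensorFlow is new enough for these tests, something along these lines can help (a diagnostic sketch only, using the module path and attribute name from the traceback above; it is not part of the test suite):

# Does the TensorFlow on the import path expose the tf2xla versioning API
# that jax2tf's test utilities call in setUp?
from tensorflow.compiler.tf2xla.python import xla as tfxla

if hasattr(tfxla, "call_module_maximum_supported_version"):
    print("supported version:", tfxla.call_module_maximum_supported_version())
else:
    print("this TensorFlow is too old for the jax2tf tests; try tf-nightly")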

dionhaefner commented 1 year ago

Thanks - tried that, same result.

$ pip freeze | grep "tf-nightly"
tf-nightly==2.15.0.dev20230919
tf-nightly-macos==2.15.0.dev20230919
jakevdp commented 1 year ago

call_module accepts a function_list argument as of four months ago: https://github.com/tensorflow/tensorflow/blame/a90eb068e805fbe39ccfd7f4bfc2e33dd8a592a0/tensorflow/compiler/tf2xla/python/xla.py#L637
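
A quick way to check whether the TensorFlow that your interpreter actually imports already has it (illustrative sketch; inspect is the standard library module, and the tfxla import path is taken from the traceback above):

# Does the call_module found on the import path accept `function_list`?
import inspect
from tensorflow.compiler.tf2xla.python import xla as tfxla

print("function_list" in inspect.signature(tfxla.call_module).parameters)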

Can you double check that the Python executable you're running is picking up the nightly tensorflow listed by pip? For example:

$ python -c "import tensorflow; print(tensorflow.__version__)"
2.15.0-dev20230919
dionhaefner commented 1 year ago

Thanks, that was it. The problem is that tensorflow_serving_api depends on tensorflow, so the stable tensorflow gets installed alongside tf-nightly and is found first.

A workaround is to do this instead:

$ pip install matplotlib flax jaxlib tensorflow_datasets tensorflow_serving_api -e jax
$ pip install tf-nightly
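
Afterwards, printing the version together with the module path confirms which TensorFlow actually wins on the import path (just the standard attributes, nothing jax-specific):

$ python -c "import tensorflow as tf; print(tf.__version__, tf.__file__)"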

Maybe the jax2tf serving guide should recommend that instead?