google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[jax2tf] Unable to wrap tf.function + convert + call_tf when using tf.Variable / stateful call_tf #18315

Open areiner222 opened 10 months ago

areiner222 commented 10 months ago

Description

I'm having trouble nesting a call_tf inside a converted JAX function when the called TF function is stateful / captures a tf.Variable:

import jax
from jax.experimental import jax2tf
import tensorflow as tf

layer = tf.keras.layers.Dense(10, dtype=tf.float32)
layer.build((None, 8))

def func_tf(x):
    return layer(x)

def func_jax(x):
    return jax.numpy.sin(jax2tf.call_tf(func_tf)(x))

@tf.function(autograph=False, jit_compile=False)
def outer_func_tf(x):
    return jax2tf.convert(func_jax)(x)

x = tf.random.uniform((10, 8))
outer_func_tf(x)

Error is:

WARNING:absl:call_tf works best with a TensorFlow function that does not capture variables or tensors from the context. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. The following captures were found [<tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Resource-70-at-0x600007088e60", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [8,10] ]")>>, <tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Resource-71-at-0x6000070895e0", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [10] ]")>>]
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[163], line 21
     19 # Calls `cos_tf` in TF eager mode
     20 x = tf.random.uniform((10, 8))
---> 21 outer_func_tf(x)

File ~/mambaforge/envs//lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

Cell In[163], line 17, in outer_func_tf(x)
     15 @tf.function(autograph=False, jit_compile=False)
     16 def outer_func_tf(x):
---> 17     return jax2tf.convert(func_jax)(x)

File ~/mambaforge/envs//lib/python3.10/site-packages/jax/experimental/jax2tf/jax2tf.py:402, in convert.<locals>.converted_fun_tf(*args_tf, **kwargs_tf)
    396   impl = GraphSerializationImpl(
    397       fun_jax,
    398       args_specs=args_specs, kwargs_specs=kwargs_specs,
    399       args_flat_tf=args_flat_tf,
    400       enable_xla=enable_xla)
    401 try:
--> 402   impl.before_conversion()
    404   outs_tree: tree_util.PyTreeDef = None  # type: ignore
    405   if with_gradient:

File ~/mambaforge/envs//lib/python3.10/site-packages/jax/experimental/jax2tf/jax2tf.py:503, in NativeSerializationImpl.before_conversion(self)
    500   _thread_local_state.call_tf_concrete_function_list = _prev_func_list
    502 self._restore_context = _restore_context
--> 503 self.exported = export.export(
    504     self.fun_jax,
    505     lowering_platform=self.lowering_platform,
    506     disabled_checks=self.native_serialization_disabled_checks
    507 )(*self.args_specs, **self.kwargs_specs)

File ~/mambaforge/envs//lib/python3.10/site-packages/jax/experimental/export/export.py:422, in export.<locals>.do_export(*args_specs, **kwargs_specs)
    420 prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions
    421 shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
--> 422 lowered = wrapped_fun_jax.lower(
    423     *args_specs, **kwargs_specs,
    424     _experimental_lowering_parameters=mlir.LoweringParameters(
    425       platforms=lowering_platforms,
    426     ))
    428 lowering = lowered._lowering  # type: ignore
    429 _check_lowering(lowering)

    [... skipping hidden 12 frame]

File ~/mambaforge/envs//lib/python3.10/site-packages/jax/experimental/jax2tf/call_tf.py:497, in _call_tf_lowering(ctx, platform, function_flat_tf, args_flat_sig_tf, has_side_effects, ordered, call_tf_graph, output_avals, *args_op, **_)
    494     else:
    495       captured_inputs.append(inp)
--> 497 captured_ops = tuple(
    498     mlir.ir_constant(np.asarray(inp))
    499     for inp in captured_inputs
    500 )
    502 if call_tf_graph:
    503   with jax2tf_internal.inside_call_tf():

File ~/mambaforge/envs//lib/python3.10/site-packages/jax/experimental/jax2tf/call_tf.py:498, in <genexpr>(.0)
    494     else:
    495       captured_inputs.append(inp)
    497 captured_ops = tuple(
--> 498     mlir.ir_constant(np.asarray(inp))
    499     for inp in captured_inputs
    500 )
    502 if call_tf_graph:
    503   with jax2tf_internal.inside_call_tf():

NotImplementedError: numpy() is only available when eager execution is enabled.

What jax/jaxlib version are you using?

jax v0.4.17, jaxlib v0.4.17, tensorflow v2.14.0

Which accelerator(s) are you using?

CPU

Additional system info

Apple M1 Max

NVIDIA GPU info

No response

areiner222 commented 10 months ago

I found that keras_core (which I believe will soon become Keras 3) provides a workaround for my example. In particular, the stateless_call method of keras_core layers makes it possible to pass the variable values in explicitly. Here is a contrived example:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import jax
from jax.experimental import jax2tf
import tensorflow as tf
import keras_core

class Jax2TFLayer(keras_core.layers.Layer):

    def __init__(self, units, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._layer = keras_core.layers.Dense(units)

    def build(self, input_shape):
        super().build(input_shape)
        self._layer.build(input_shape)

    def _call_tf(self, x, trainable_variables, non_trainable_variables):
        return self._layer.stateless_call(trainable_variables, non_trainable_variables, x)[0]  

    def _call_jax(self, x, trainable_variables, non_trainable_variables):
        length = 100
        def f(x, i):
            out = jax2tf.call_tf(self._call_tf)(x, trainable_variables, non_trainable_variables)
            return x, out
        return jax.lax.scan(f, x, jax.numpy.arange(length), length=length)

    @tf.function(autograph=False, jit_compile=True)
    def call(self, x):
        return jax2tf.convert(self._call_jax)(
            x, [v.value for v in self._layer.trainable_variables], [v.value for v in self._layer.non_trainable_variables]
        )[1]

layer = Jax2TFLayer(16)

inp = tf.random.uniform((10, 16))
out = layer(inp)

This could work great for my use case, but I can't use it just yet: as I understand it, shape polymorphism in call_tf doesn't work with native_serialization, which I currently require.

gnecula commented 10 months ago

I will take a look at this repro, but in general the call_tf mechanism works best when there are no tf.Variables captured by the called function. It is best to refactor the TF code to pass the variable values in and out of the TF function.
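A minimal sketch of that refactoring applied to the original repro (an illustration only, not an official recipe; it assumes the Dense layer has no activation, so its call is just a matmul plus bias, and it reads the variable values eagerly before entering the tf.function):

import jax
from jax.experimental import jax2tf
import tensorflow as tf

layer = tf.keras.layers.Dense(10, dtype=tf.float32)
layer.build((None, 8))

# Pass the variable values in explicitly so the called TF function
# captures no tf.Variables.
def func_tf(x, kernel, bias):
    return tf.matmul(x, kernel) + bias

def func_jax(x, kernel, bias):
    return jax.numpy.sin(jax2tf.call_tf(func_tf)(x, kernel, bias))

@tf.function(autograph=False, jit_compile=False)
def outer_func_tf(x, kernel, bias):
    return jax2tf.convert(func_jax)(x, kernel, bias)

x = tf.random.uniform((10, 8))
# Read the variable values eagerly, outside the tf.function.
outer_func_tf(x, layer.kernel.read_value(), layer.bias.read_value())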

gnecula commented 9 months ago

I looked at the original repro.

The problem here is that when call_tf processes func_tf, it notices that func_tf captures two tf.Variables and tries to read their values using .numpy(). This works if the code runs in eager mode, but not when the outer code is under tf.function. You can verify this by removing tf.function from your repro.
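A minimal illustration of that failure mode, independent of jax2tf (a sketch with illustrative names; the exact error message may vary across TF versions):

import numpy as np
import tensorflow as tf

v = tf.Variable(tf.ones((8, 10)))

def read_variable():
    # np.asarray needs a concrete value; inside tf.function the tensor is
    # symbolic, so this raises NotImplementedError instead of returning data.
    return np.asarray(v.read_value())

read_variable()               # eager mode: returns a NumPy array
tf.function(read_variable)()  # graph mode: raises NotImplementedError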

@maxwillzq Can you please take a look if there is some workaround?

maxwillzq commented 9 months ago

Sure.

and it tries to read the values of those variables using .numpy()

Could you point me to where it calls .numpy() in the keras lib? I will try to see if we can remove it.

gnecula commented 9 months ago

The call to numpy happens as part of np.asarray(inp) here. This works if we are in eager mode but fails under tf.function.

maxwillzq commented 9 months ago

@areiner222 As @gnecula suggested, the best way is to rewrite the TF function so that it has no captured tf.Variables. Another option is https://www.tensorflow.org/api_docs/python/tf/config/run_functions_eagerly, but it carries a performance penalty.

I tested it on Colab (https://shorturl.at/bkwHJ) and it works.
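Applied to the original repro, the workaround would look roughly like this (a sketch; the performance penalty mentioned above still applies):

# Force tf.function-decorated code to run eagerly so that the captured
# variable values can be read with .numpy() during call_tf lowering.
tf.config.run_functions_eagerly(True)
try:
    outer_func_tf(x)
finally:
    tf.config.run_functions_eagerly(False)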

By the way, I tried the same approach in jax2tf.py, but it is much more complicated there. @gnecula, do you know the best place to insert this context-manager call? Thanks.

gnecula commented 9 months ago

I do not think that we should change jax2tf to force running of TF functions eagerly.

If the user starts with a tf.keras layer, which references variables, is it possible to split it into a set of variable values and a function that takes the variables as inputs?

Edgeworth commented 6 months ago

It would be nice if there were some other way of resolving this. I am loading a SavedModel (originally JAX code) to run inside more JAX code, which I then want to export into another SavedModel, but it is tricky to use call_tf on the inner SavedModel while also passing in all the variables etc. it uses (since they are TF tensors).

Edgeworth commented 5 months ago

To resolve this when functions are not run eagerly, what about modifying _call_tf_lowering to grab the values of the variables? This is what I had to do to get my use case working for https://github.com/google/jax/issues/11753:

  if tf.executing_eagerly():
    np_captured_inputs = [np.asarray(inp) for inp in captured_inputs]
  else:
    if captured_inputs:
      with tf.compat.v1.Session(graph=captured_inputs[0].graph) as sess:
        sess.run([v.initializer for v in captured_inputs])
        np_captured_inputs = sess.run(captured_inputs)
    else:
      np_captured_inputs = []

  captured_ops = tuple(
      mlir.ir_constant(inp)
      for inp in np_captured_inputs
  )