jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[call_tf] implementation of batching rule for vmap #11753

Open DavidDevoogdt opened 2 years ago

DavidDevoogdt commented 2 years ago

Hi,

I'm trying to use call_tf in combination with jacrev, and hence vmap. No batching rule (primitive_batchers entry) is implemented for the call_tf_p primitive. I tried modifying jax/experimental/jax2tf/call_tf.py to register the generic broadcasting rule with batching.defbroadcasting(call_tf_p), but this doesn't work (wrong result for vmap, error for jacrev).

I'm not sure how to finish this as I'm not familiar with the internals of jax:

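import numpy as np
from jax.interpreters import batching
from jax.experimental.jax2tf.call_tf import call_tf_p

def _tf_batcher(args, dims, callable_flat_tf, function_flat_tf, args_flat_sig_tf, **params):
    flat_shape = np.array(args_flat_sig_tf[0].shape)
    # TODO: probably call bind once per slice along the batched dimension,
    # then stack the results and return (outputs, output_batch_dims)
    return call_tf_p.bind(*args, callable_flat_tf=callable_flat_tf,
                          function_flat_tf=function_flat_tf,
                          args_flat_sig_tf=args_flat_sig_tf, **params)

batching.primitive_batchers[call_tf_p] = _tf_batcher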

As a test case, I would like this TensorFlow round-trip code to work as expected:

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax import jacrev, jit, vmap
from jax.experimental import jax2tf

def test_call_tf_batcher():

    @jit
    def f(x):
        return jnp.sum(x * x)

    f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))

    x = np.random.random((7))
    y = np.random.random((5, 7))

    # test f==f_rt

    print(f"{f(x)}, {f_rt(x)}")
    print(f"{jax.grad(f)(x)}, {jax.grad(f_rt)(x)}")

    # test vmap
    f_rt_v = vmap(f_rt)
    f_v = vmap(f)
    print(f"{f_v(y)} == {f_rt_v(y)}")

    # test jacobian
    j = jacrev(f)
    j_rt = jacrev(f_rt)
    print(f"{j(x)} == {j_rt(x)}")

if __name__ == "__main__":
    test_call_tf_batcher()

Any help would be greatly appreciated.

gnecula commented 2 years ago

In general, it is not possible to add a proper batching rule for call_tf because the called function can in principle be an arbitrary TF function for which one cannot define a generic batching rule (except perhaps by running the function in a loop). Even in the case when the called function originates from JAX, by the time it gets to call_tf the function has been staged out to HLO so it is not possible to use the JAX batching rules anymore.

For a round-trip to TF, you have to apply the vmap before converting to TF.

What is the use case for which you want to do a round trip through TensorFlow?

DavidDevoogdt commented 2 years ago

I don't really need a round trip, just code to convert a TensorFlow function to a JAX function f: R^(n x 3) -> R^m (i.e., an m-dimensional function of n 3D molecular coordinates), and I need its Jacobian matrix. I've managed to implement the batching procedure as a loop:

import functools

import jax.numpy as jnp
from jax.interpreters import batching
from jax.experimental.jax2tf.call_tf import call_tf_p

def loop_batcher(prim, args, dims, **params):

    # determine the size of the mapped axis from the first batched argument
    for ni, di in enumerate(dims):
        if di is not None:
            axis_size = args[ni].shape[di]
            break

    # for every argument, build one index tuple per batch element;
    # unbatched arguments (dims entry None) are passed whole each time
    args_indices = []
    for xi, ia in zip(args, dims):
        indices = [slice(None)] * xi.ndim

        l = []
        for i in range(axis_size):
            if ia is not None:
                indices[ia] = i

            l.append(tuple(indices))
        args_indices.append(l)

    # apply the primitive once per batch element
    out = []
    for inds in zip(*args_indices):
        outp = prim.bind(*[a[b] for a, b in zip(args, inds)], **params)
        if not prim.multiple_results:
            outp = (outp,)
        out.append(outp)

    # stack the per-element results along a new leading batch axis
    # (jnp.stack keeps multi-dimensional outputs in the right order,
    # which hstack followed by reshape does not)
    ret = [jnp.stack(out_args) for out_args in zip(*out)]

    if not prim.multiple_results:
        return ret[0], 0
    return (ret, (0,) * len(ret))

batching.primitive_batchers[call_tf_p] = functools.partial(
    loop_batcher, call_tf_p)
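As an alternative to registering a rule in the internal primitive_batchers table, the same per-element loop can be attached to a single function with jax.custom_batching.custom_vmap, a public JAX API. This is a swapped-in technique, not what the code above does. A minimal sketch, assuming the opaque function takes one 1-D argument; opaque_fn, g and g_vmap_rule are hypothetical names, and I have not checked how this composes with call_tf or with jacrev:

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

def opaque_fn(x):
    # Hypothetical stand-in for a function without a usable batching rule
    # (e.g. one that goes through call_tf); expects a 1-D input.
    return jnp.array([jnp.sum(x * x), jnp.prod(x)])

@custom_vmap
def g(x):
    return opaque_fn(x)

@g.def_vmap
def g_vmap_rule(axis_size, in_batched, xs):
    # custom_vmap passes the batched argument with its batch axis at 0.
    assert in_batched[0] and xs.shape[0] == axis_size
    # Loop over the batch, one call per element, then stack the results.
    out = jnp.stack([opaque_fn(xs[i]) for i in range(axis_size)])
    return out, True  # the output is batched along axis 0

ys = jnp.linspace(0.1, 1.0, 21).reshape(3, 7)
print(jax.vmap(g)(ys).shape)  # (3, 2)
```

Scoping the loop to one function this way leaves vmap behavior of every other call_tf use unchanged.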

Below is the testing code:

import jax.numpy as jnp
import numpy as np
from jax import grad, jacrev, jit, vmap
from jax.experimental import jax2tf

def test_call_tf_batcher():

    @jit
    def f(x):
        return jnp.array([jnp.sum(x * x), jnp.prod(x)])

    f_t = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))

    x = np.random.random((7))
    y = np.random.random((5, 7))

    print(f"||f(x) - f_t(x)||_2 = { jnp.linalg.norm( f(x)-f_t(x) )  }")

    # test vmap
    f_v = jit(vmap(f))
    f_v_t = jit(vmap(f_t))

    print(f"||f(y) - f_t(y)||_2 = { jnp.linalg.norm( f_v(y)-f_v_t(y) )  }")

    j = jit(jacrev(f))
    j_t = jit(jacrev(f_t))

    print(
        f"||jac(f)(x) - jac(f_t)(x)||_2 = { jnp.linalg.norm( j(x)-j_t(x) )  }")

if __name__ == "__main__":
    test_call_tf_batcher()

The code seems to work fine for vmap and jacrev, and it can be jit-compiled. In my case the performance penalty is low because the output dimension m is small (2 to 3): jacrev maps over the m output directions, so the loop only runs m times. As I said before, I don't know the internals of JAX, so there might still be some errors in my code. The approach is quite general, so it might also be useful in other places where batching is not implemented :)
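The reason the loop stays cheap can be checked in plain JAX: jax.jacrev builds the Jacobian of f: R^n -> R^m by mapping the VJP over the m standard-basis cotangents, so a loop-based batching rule ends up running roughly m iterations. A minimal sketch of that equivalence, reusing the two-output test function above with an arbitrary input x (an assumption):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Same two-output test function as above: R^7 -> R^2
    return jnp.array([jnp.sum(x * x), jnp.prod(x)])

x = jnp.linspace(0.5, 1.5, 7)
y, pullback = jax.vjp(f, x)
m = y.shape[0]

# One pullback per standard-basis cotangent e_i gives row i of the Jacobian;
# jacrev performs these m pullbacks under vmap.
rows = [pullback(jnp.eye(m)[i])[0] for i in range(m)]
jac_rows = jnp.stack(rows)

print(jac_rows.shape)  # (2, 7)
```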

gnecula commented 2 years ago

Your solution should work in principle (I have not checked it in all detail), but I do not feel that it is a solution that we want to upstream to JAX because it breaks the expectation that jax.vmap is generally more efficient than running the underlying function repeatedly for all mapped elements.

One thing we could do is to add a reference to this issue and your solution in the documentation of call_tf and let the users decide whether this workaround works for them or not. WDYT?

DavidDevoogdt commented 2 years ago

Sounds good to me. Another option would be to add a configurable flag, SLOW_TF_VMAP = False: unless the flag is set to True, the batching function raises an error with a link to the relevant documentation page.

P.S. There was a small error in the code; I have edited my previous post. The code now seems to work for the more complicated problem I was trying to solve in the first place :)

DavidDevoogdt commented 2 years ago

Some parts of my code were annoyingly slow, so I reimplemented the batcher using tf.vectorized_map. Basically, I redo the call_tf on the vectorized TensorFlow function and pipe the results back to JAX. The unmapped arguments are partially applied beforehand.

import functools

import tensorflow as tf
from jax.experimental.jax2tf import call_tf

def loop_batcher(prim, args, dims, **params):

    # rebuild the full positional argument list from the batched
    # and the static (unmapped) arguments
    def apply(batch_args, f, static_args, static_pos):
        arguments = []

        l = len(batch_args) + len(static_args)
        j = 0
        k = 0
        for i in range(l):
            if i in static_pos:
                arguments.append(static_args[j])
                j += 1
            else:
                arguments.append(batch_args[k])
                k += 1

        return f(*arguments)

    static_pos = []
    static_args = []
    batch_args = []

    # split the arguments into static (unmapped) and batched ones
    for i, (arg, batch_axis) in enumerate(zip(args, dims)):
        if batch_axis is None:
            static_pos.append(i)
            static_args.append(arg)
        else:
            assert batch_axis == 0, 'other position not yet implemented'
            batch_args.append(arg)

    if len(batch_args) != 1:
        raise NotImplementedError('only one batched argument is supported')

    # vectorize the TF callable over the leading axis of the batched argument
    def par_fun(batch_args, static_args):
        return tf.vectorized_map(
            fn=functools.partial(apply, f=params['callable_flat_tf'],
                                 static_args=static_args, static_pos=static_pos),
            elems=batch_args)

    # execute with the given arguments; every output is batched along axis 0
    ret = call_tf(par_fun)(batch_args, static_args)
    return (ret, (0,) * len(ret))

What are your thoughts on this? The function needs more work if a batch dimension is not at position 0, or if more than one argument is batched at once.
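One way to lift the position-0 restriction is to normalize the batch dimensions before splitting the arguments. Once every batched argument has its batch axis at 0, they could be passed to tf.vectorized_map together as a tuple, since its elems argument may be a (possibly nested) sequence of tensors, each unpacked along its first dimension. A minimal sketch of the normalization step; batch_dims_to_front is a hypothetical helper, and jnp.moveaxis is used so that it also accepts JAX tracers:

```python
import jax.numpy as jnp

def batch_dims_to_front(args, dims):
    """Move each batched argument's batch axis to position 0.

    `dims` uses the batching-rule convention: an int axis for a batched
    argument and None for an unbatched one. Returns the new args and dims.
    """
    sizes = {a.shape[d] for a, d in zip(args, dims) if d is not None}
    if len(sizes) != 1:
        raise ValueError(f"expected exactly one batch size, got {sizes}")
    new_args = [a if d is None else jnp.moveaxis(a, d, 0)
                for a, d in zip(args, dims)]
    new_dims = [None if d is None else 0 for d in dims]
    return new_args, new_dims
```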

Edgeworth commented 7 months ago

In case anyone is interested: I was also looking at this because I wanted to load, run, and re-save JAX models that I had previously saved as TF SavedModels, and to repeat that an arbitrary number of times (that is, to compose SavedModels). My use case only needs a single batch dimension, as the first dimension, but it may be possible to adapt the code:

from dataclasses import dataclass
from typing import Any, Callable, Sequence

import jax
import numpy as np
from jax.experimental import jax2tf

# TfVal and JaxArrayOrMap are type aliases whose definitions are not shown
# here; Any is used as a placeholder for both.
TfVal = Any
JaxArrayOrMap = Any

@dataclass(eq=True, kw_only=True, order=True, frozen=True)
class _ShapeAndDtype:
    shape: tuple
    dtype: type

# Passthrough batcher for TF SavedModels. Assumes the first dimension is batched and no other.
def _tf_passthrough_batcher(
    fn: Callable,
    inp: JaxArrayOrMap,
    batched_args: tuple,
    batched_dims: tuple,
    call_tf_graph: bool,
    callable_flat_tf: Callable[[list[TfVal]], Sequence[TfVal]],
    **_kwargs: dict,
) -> tuple:
    assert len(batched_dims) == 1
    assert len(batched_args) == 1
    treedef = jax.tree_util.tree_structure(inp)
    # Map non-integers to 1 to handle polymorphic inputs, e.g. on save.
    input_shape = tuple([int(v) if isinstance(v, int) else 1 for v in batched_args[0].shape])
    # Force call callable_flat_tf to fill `res_treedef` inside it.
    out = callable_flat_tf(np.zeros(input_shape))  # type: ignore[arg-type]
    assert len(out) == 1
    output_shape: list[int] = list(out[0].shape)
    # Grab value which may be non-integer (polymorphic) from the batch dimension.
    batched_output_shape = (batched_args[0].shape[0], *output_shape[1:])
    # Assumes a single output.
    output_shape_dtype = _ShapeAndDtype(shape=batched_output_shape, dtype=np.float32)
    args = treedef.unflatten(batched_args)
    ret = jax2tf.call_tf(fn, call_tf_graph=call_tf_graph, output_shape_dtype=output_shape_dtype)(
        args
    )
    return ([ret], (0,))

It basically just passes through the batch operation since it assumes that the tf SavedModel you're using has a polymorphic input on the batch dimension (which is true for my use case). You need to invoke it like this:

batching.primitive_batchers[call_tf_p] = functools.partial(_tf_passthrough_batcher, fn, inp)
# call_tf_graph supports polymorphic inputs for saving to a SavedModel.
# But, it does not work for training / running at all, just for saving.
jax2tf.call_tf(fn, call_tf_graph=is_saving)(inp)
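The same pass-through idea can be expressed with jax.custom_batching.custom_vmap, a public JAX API that attaches a batching rule to one function instead of patching the global primitive_batchers table. This is a swapped-in technique, not the code above. A minimal sketch, assuming the wrapped model natively accepts inputs with or without a leading batch dimension; batch_polymorphic_model is a hypothetical stand-in for the loaded SavedModel, and this has not been tried with call_tf, call_tf_graph or shape-polymorphic export:

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

def batch_polymorphic_model(x):
    # Hypothetical stand-in for a model with a polymorphic leading batch
    # dimension: it accepts (d,) or (b, d) inputs alike.
    return jnp.tanh(x) @ jnp.ones((x.shape[-1], 3))

@custom_vmap
def model(x):
    return batch_polymorphic_model(x)

@model.def_vmap
def model_vmap_rule(axis_size, in_batched, xs):
    # Pass-through: the batch axis is at position 0 and the model handles
    # it natively, so call it once on the whole batch instead of looping.
    assert in_batched[0] and xs.shape[0] == axis_size
    return batch_polymorphic_model(xs), True

xs = jnp.linspace(-1.0, 1.0, 20).reshape(4, 5)
print(jax.vmap(model)(xs).shape)  # (4, 3)
```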

There is some additional complication around saving versus running/training a model that contains the loaded SavedModel. For training/running, you want to use eager execution (the default), otherwise it won't work (I don't remember the exact reason). For saving, I want to preserve the polymorphic input dimension (the batch dimension), but trying to save with eager execution gives this error:

ValueError: Error compiling TensorFlow function (see below for the caught exception).
call_tf can used in a staged context (under jax.jit, lax.scan, etc.) only with compilable functions with static output shapes.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.

So you need to use the experimental call_tf_graph flag, and only for saving the model. It doesn't work with eager execution, since it needs the TensorFlow graph object, so we need to disable eager execution when saving and enable call_tf_graph for call_tf:

tf.compat.v1.disable_eager_execution()

Now that eager execution is disabled, the model needs access to a number of TF variables from the SavedModel we loaded as part of our larger model. Those variables are not initialized, however, so when we export using Orbax we have to provide them as extra trackable TF resources.

jax_module = JaxModule(
    ...,
    {
        "predict": ...
    },
    # Our jax code is polymorphic on the batch size, so let the export do that too.
    input_polymorphic_shape={"predict": "(b, ...)"},
)

serving_configs = [
    ServingConfig(
        ...
        extra_trackable_resources=extra_trackable_resources,
    )
]

The extra trackable resources need to be the variables from the saved tf model we loaded, for example:

self.model.signatures["serving_default"].variables

Then, to actually export, we need to set up a TensorFlow session in which both the extra trackable resources and the TF variables the JAX module needs (e.g. those of the larger model we embedded the existing SavedModel in) are initialized:

with tf.compat.v1.Session(
    graph=extra_trackable_resources[0].graph
).as_default() as sess:
    # Run initializers
    if extra_trackable_resources:
        sess.run([v.initializer for v in extra_trackable_resources])
    sess.run([v.initializer for v in jax_module.variables])

    export_mgr = ExportManager(jax_module, serving_configs=serving_configs)
    export_mgr.save(path)

Unfortunately this isn't quite enough because call_tf doesn't support non-eager execution (see also https://github.com/google/jax/issues/18315 ), so we need to patch _call_tf_lowering to get the variables out of the graph if we're not executing eagerly:

  if tf.executing_eagerly():
    np_captured_inputs = [np.asarray(inp) for inp in captured_inputs]
  else:
    if captured_inputs:
      with tf.compat.v1.Session(graph=captured_inputs[0].graph) as sess:
        # Run the initializers of the captured variables in their graph
        sess.run([v.initializer for v in captured_inputs])
        # Then read their values back as NumPy arrays
        np_captured_inputs = sess.run(captured_inputs)
    else:
      np_captured_inputs = []

  captured_ops = tuple(
      mlir.ir_constant(inp)
      for inp in np_captured_inputs
  )

There is also a small bug in _to_tf_variable in jax_module.py: it does not pass the correct device name, so .name needs to be added to default_cpu_device:

with tf.device(default_cpu_device.name):
    return tf.Variable(
        x, trainable=trainable, shape=x.shape, dtype=x.dtype, name=name
    )

With this, it's possible to round-trip a JAX model with a batch dimension to a TF SavedModel, load it back, train it as part of a larger model, and save it again while preserving the batch dimension, so batching survives repeated save/load cycles.

Unrelated, but _call_tf_lowering has an implicit dependency on tf >= 2.16 since it passes the platform argument to experimental_get_compiler_ir. I don't think this dependency is documented anywhere.
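If code relies on that behavior, a small guard can fail early on older TensorFlow instead of erroring deep inside the lowering. A minimal sketch; tf_version_at_least is a hypothetical helper that compares only the major.minor prefix of a version string such as tf.__version__:

```python
import re

def tf_version_at_least(version: str, minimum: tuple) -> bool:
    """Return True if the major.minor prefix of `version` is >= `minimum`."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)

# Typical use (requires TensorFlow):
#   import tensorflow as tf
#   if not tf_version_at_least(tf.__version__, (2, 16)):
#       raise RuntimeError("this call_tf lowering needs TensorFlow >= 2.16")
```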