patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Make diffeqsolve convertable to TensorFlow #202

Open llandsmeer opened 1 year ago

llandsmeer commented 1 year ago

Through a talk on NODEs (neural ODEs) on YouTube I came across this package, and it looks perfect for a project we are planning (thanks for the great talk!). However, one of the platforms where we want to run our code does not support JAX/XLA/TensorFlow, only ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because the unsupported unvmap_any primitive is used (at compile time!) to deduce the number of iterations needed.

Minimal example:

import tensorflow as tf
import jax.numpy as jnp
import tf2onnx

from diffrax import diffeqsolve, ODETerm, Euler
from jax.experimental import jax2tf

def simulate(y0):
    solution = diffeqsolve(
            terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
            t0=0, t1=1, dt0=0.1, y0=y0)
    return solution.ys[0]

# This works
x = simulate(100)
assert jnp.isclose(x, jnp.exp(-1)*100, atol=.1, rtol=.1)

simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))

# Does not work:
# simulate_tf(100)
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented

# Also does not work:
tf2onnx.convert.from_function(
        simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented

For us, it would be really nice to use a GPU/TSP during training with JAX, then transfer to this specific piece of hardware, which only supports ONNX, for inference (at that point I no longer need gradient calculations). Of course, solving this might be completely outside the scope of the project, and there are other solutions, like writing the solvers from scratch or using existing solvers in TF/PyTorch.

Currently my knowledge of JAX is limited (hopefully this will soon improve!). If this is the only function stopping Diffrax from being TensorFlow-convertible, maybe a small workaround is possible. I'm also happy with an answer like 'no, we don't do this' or 'send us a PR if you want to have this fixed'.

patrick-kidger commented 1 year ago

Okay, this is an interesting feature request! You successfully nerd-sniped me into implementing this.

First of all, TL;DR: this is now available on the dev branch of Equinox as equinox.internal.to_onnx. See equinox/tests/test_onnx.py for example usage. Try installing Equinox from that branch and let me know what you think!


More details as follows.

When JAX runs your code, it traces it into a jaxpr, which is a representation of the program that you want to run. A jaxpr is essentially just a sequence of "primitives" (addition, cosines, matrix multiplies, ...), written out as one long list of equations.
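To see this concretely, here is a small sketch (not from the thread, assuming a recent JAX) that uses jax.make_jaxpr to trace a toy function and list the primitives in its jaxpr. The function f is purely illustrative; it calls lax directly so that each call maps onto exactly one primitive.

```python
import jax
from jax import lax


def f(x):
    # Each lax call below becomes one equation (one primitive) in the jaxpr.
    return lax.add(lax.cos(x), lax.mul(x, x))


closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)  # the full jaxpr, one equation per line

# The jaxpr is a flat list of equations; each one names its primitive.
prims = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prims)  # ['cos', 'mul', 'add']
```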

When you want to jax.vmap / jax.grad etc. your code, it works by performing a jaxpr-to-jaxpr transformation. It iterates through the jaxpr, and looks up a translation rule for each primitive. For example, vmap(jnp.dot) works by looking up the batching rule for a dot product, and returning the operations for performing a matrix-vector product. (Which is what a batched-dot-product is.) The result is a new jaxpr representing the transformed program.
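A sketch of the vmap case, assuming a recent JAX: batching a vector-vector lax.dot over the rows of a matrix produces a jaxpr built around a dot_general primitive, i.e. a matrix-vector product. The names dot, M and v are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dot(a, b):
    # Unbatched: a plain vector-vector dot product.
    return lax.dot(a, b)


M = jnp.arange(6.0).reshape(2, 3)
v = jnp.array([1.0, 0.5, -1.0])

# Batch over the rows of M while keeping v fixed.
batched_dot = jax.vmap(dot, in_axes=(0, None))

print(jax.make_jaxpr(dot)(M[0], v))      # the unbatched program
print(jax.make_jaxpr(batched_dot)(M, v))  # rewritten by the batching rule
print(batched_dot(M, v))                  # equals M @ v, i.e. [-1.5, 0.0]
```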

This is quite a powerful paradigm, and is also how the JAX->TF export works: the exporter runs through the jaxpr of your program, and looks up a JAX->TF conversion rule for every primitive. The result is a TensorFlow program.
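The "look up a rule for every primitive" idea can be sketched as a toy exporter. This is not how jax2tf is actually implemented: export_to_numpy and NUMPY_RULES are hypothetical names, and the toy only handles single-output primitives and ignores primitive params. It replays a jaxpr using NumPy and, much like jax2tf with unvmap_any, raises NotImplementedError when a primitive has no rule.

```python
import numpy as np
import jax
from jax import lax

# Hypothetical rule table, one entry per primitive name, in the same
# spirit as jax2tf's tf_impl dictionary. eqn.params are ignored here.
NUMPY_RULES = {
    "add": np.add,
    "mul": np.multiply,
    "cos": np.cos,
    "sin": np.sin,
}


def export_to_numpy(fn, *example_args):
    """Trace `fn` to a jaxpr, then return a NumPy function that replays it."""
    closed = jax.make_jaxpr(fn)(*example_args)
    jaxpr = closed.jaxpr

    def run(*args):
        env = {}

        def read(var):
            # Literals carry their value directly; variables live in `env`.
            return var.val if hasattr(var, "val") else env[var]

        for var, val in zip(jaxpr.constvars, closed.consts):
            env[var] = np.asarray(val)
        for var, val in zip(jaxpr.invars, args):
            env[var] = np.asarray(val)
        for eqn in jaxpr.eqns:
            name = eqn.primitive.name
            if name not in NUMPY_RULES:
                # Same failure mode as a missing jax2tf rule.
                raise NotImplementedError(
                    f"NumPy interpretation rule for '{name}' not implemented")
            result = NUMPY_RULES[name](*[read(v) for v in eqn.invars])
            env[eqn.outvars[0]] = result  # toy: single-output primitives only
        return [read(var) for var in jaxpr.outvars]

    return run


f_np = export_to_numpy(lambda x: lax.add(lax.cos(x), lax.mul(x, x)), 1.0)
print(f_np(0.5))  # NumPy evaluation of cos(0.5) + 0.25

g_np = export_to_numpy(lambda x: lax.exp(x), 1.0)
try:
    g_np(0.5)
except NotImplementedError as err:
    print(err)  # no rule registered for 'exp'
```

Supporting a new primitive is then just a matter of adding an entry to the rule table, which is the "register a conversion rule per primitive" route discussed next.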

Now, JAX allows us to define custom primitives. Indeed, Diffrax uses a few; one such is the unvmap_any you've spotted. Defining custom primitives is useful to be able to more carefully control how your program interacts with these JAX transformations (vmap, grad etc.) However, whilst library authors such as myself tend to implement all the primitive rules for vmap/grad etc., we often don't implement rules for more esoteric transformations, such as the JAX->TF exporter.

So: one possible solution is to register a custom conversion rule for each such custom primitive. This is totally doable, but quite a lot of work.

Fortunately, there's a shortcut. All of the primitives used by Diffrax are actually just "wrapper primitives": they wrap other pieces of JAX code to provide them with custom behaviour with respect to vmap/grad/whatever. But the underlying operations are still just JAX operations, which we expect the JAX->TF exporter to be able to handle just fine.
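jax.custom_jvp is a built-in example of the same "wrapper" idea (it is not one of Diffrax's or Equinox's own primitives): it wraps ordinary JAX code and only changes how that code behaves under differentiation. The sketch below, assuming a recent JAX, uses a hypothetical straight-through rounding function ste_round: the jaxpr shows a custom_jvp_call wrapping a plain round primitive.

```python
import jax
from jax import lax


@jax.custom_jvp
def ste_round(x):
    # The wrapped computation is just an ordinary JAX primitive.
    return lax.round(x)


@ste_round.defjvp
def _ste_round_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Custom behaviour: pass the tangent straight through, as if d/dx = 1.
    return lax.round(x), t


print(jax.make_jaxpr(ste_round)(1.3))  # custom_jvp_call wrapping round
print(jax.grad(ste_round)(1.3))        # 1.0: the custom rule is used
print(jax.grad(lax.round)(1.3))        # 0.0: the plain primitive's own rule
```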

Thus, https://github.com/patrick-kidger/equinox/pull/243 adds a new jaxpr->jaxpr transformation, "finalisation", equinox.internal.finalise_jaxpr, which simply ignores all this custom behaviour and replaces each custom primitive with whatever JAX operations it is wrapping.

This means that doing vmap(finalise_jaxpr(your stuff here)) would be wrong! We would have removed the custom behaviour our custom primitives set out to introduce -- and then vmap'd it.

However, as long as "finalisation" is, well, the final thing you do, you can now perform export_to_onnx(finalise_jaxpr(your stuff here))! The result of finalise_jaxpr(your stuff here) is something with all the custom wrapper primitives stripped out, and thus is something that the exporter knows how to handle.
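A rough sketch of what such a finalisation pass could look like, written for the custom_jvp wrapper rather than Equinox's own primitives (this is not Equinox's finalise_jaxpr). It assumes a recent JAX, where the wrapper appears in the jaxpr as custom_jvp_call with a call_jaxpr parameter; _eval_finalised and finalise are hypothetical names. It replays the jaxpr, inlining each wrapper's body. Using grad in place of vmap, it also shows why finalising before transforming is wrong: the custom gradient rule disappears.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_jvp
def ste_round(x):
    return lax.round(x)


@ste_round.defjvp
def _ste_round_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return lax.round(x), t  # straight-through gradient


def _eval_finalised(jaxpr, consts, *args):
    """Replay a jaxpr, inlining custom_jvp_call wrappers instead of binding them."""
    env = {}

    def read(var):
        # Literals carry their value directly; variables live in `env`.
        return var.val if hasattr(var, "val") else env[var]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name == "custom_jvp_call":
            # Drop the custom JVP rule; recurse into the wrapped computation.
            inner = eqn.params["call_jaxpr"]
            outvals = _eval_finalised(inner.jaxpr, inner.consts, *invals)
        else:
            out = eqn.primitive.bind(*invals, **eqn.params)
            outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
    return [read(v) for v in jaxpr.outvars]


def finalise(fn, *example_args):
    closed = jax.make_jaxpr(fn)(*example_args)

    def finalised(*args):
        args = [jnp.asarray(a) for a in args]
        (out,) = _eval_finalised(closed.jaxpr, closed.consts, *args)
        return out

    return finalised


f = finalise(ste_round, 1.3)
print(jax.make_jaxpr(f)(1.3))  # plain round, no custom_jvp_call left
print(f(1.3))                   # 1.0: same forward value
print(jax.grad(f)(1.3))         # 0.0: the custom gradient is gone
```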

And indeed, if you look in the source code for equinox.internal.to_onnx, pretty much the first thing it does is to "finalise" your code. If you want, you could call the finalisation transformation yourself, and then do the ONNX export the same way you already are.


This is a very quick thing I just threw together, so I make no promises! In particular, I've not tried it on any Diffrax code yet, and this definitely isn't a stable API. But let me know how it goes.

llandsmeer commented 1 year ago

Wow, this was quick. Thank you very much for putting in the effort, and for the entire explanation of Equinox/JAX internals :) Mapping directly to ONNX is a great idea.

To test it, I ran

pip uninstall diffrax equinox
pip install -U https://github.com/patrick-kidger/diffrax/archive/main.zip
pip install -U https://github.com/patrick-kidger/equinox/archive/dev.zip

to update to the latest versions. Then, if I understood correctly, I should call to_onnx() on it as follows:

import tensorflow as tf
import jax.numpy as jnp
import tf2onnx

from diffrax import diffeqsolve, ODETerm, Euler
from jax.experimental import jax2tf

import equinox.internal as eqxi

def simulate(y0):
    solution = diffeqsolve(
            terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
            t0=0, t1=1, dt0=0.1, y0=y0
            )
    return solution.ys[0]

onnx_generator_fn = eqxi.to_onnx(simulate, wrapper_prims=[
    eqxi.unvmap_any_p
    ])

f = onnx_generator_fn(10)
print(f(100))

Which gives the error message:

    File "/home/llandsmeer/.local/lib/python3.10/site-packages/jax/experimental/jax2tf/jax2tf.py", line 1183, in get_primitive_impl
        raise NotImplementedError(msg.format(p)) from err

    NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented

which I think means we're missing an xla.register_translation() or mlir.register_lowering() call in equinox/internal/unvmap.py?

Edit: no, the register() calls look exactly like other examples I can find online.

llandsmeer commented 1 year ago

(Some random dumps of my crude attempts at debugging this)

Using some print statements, I saw that register_translation() does get called.

The error originates here in jax2tf.py

  def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
    # Returns the primitive implementation and whether the implementation
    # takes abstract values (see definition of tf_impl_with_avals)
    if not _thread_local_state.enable_xla:
      try:
        return tf_impl_no_xla[p], True  # Always require avals.
      except KeyError:
        pass
    try:
      return tf_impl[p], False
    except KeyError:
      try:
        return tf_impl_with_avals[p], True
      except KeyError as err:
        msg = "TensorFlow interpretation rule for '{}' not implemented"
        raise NotImplementedError(msg.format(p)) from err

None of tf_impl, tf_impl_no_xla or tf_impl_with_avals contains anything like Primitive('unvmap_any').

Now if I add this (yes editing the jax2tf globals)...

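import tensorflow as tf
import equinox.internal as eqxi
from jax.experimental.jax2tf import jax2tf

def tf_error_impl(pred, index, *x, msgs):
    # Crude placeholder: skips the runtime error check entirely and
    # just passes the operands through unchanged.
    return x

# Monkey-patch jax2tf's rule table with TF implementations
# for Equinox's custom primitives.
jax2tf.tf_impl[eqxi.unvmap_any_p] = tf.experimental.numpy.any
jax2tf.tf_impl[eqxi.unvmap_max_p] = tf.experimental.numpy.max
jax2tf.tf_impl[eqxi.branched_error_p] = tf_error_impl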

I can seemingly get it to become a TF function (although not in the intended way? I guess finalise() was meant to remove the unvmap_any() calls? I'm very much in no-idea-what-I'm-doing territory here).

Then I get bombarded with a list of tf2onnx errors, due to these two unsupported tf2onnx ops:

ERROR:tf2onnx.tfonnx:Unsupported ops: Counter({'_SwitchN': 15, 'Merge': 15})

...

patrick-kidger commented 1 year ago

Regarding the original error with unvmap_any: what was going on was that these were inside "higher order primitives", e.g. lax.scan/lax.cond, and the finalisation transform wasn't looking inside of these. That was just a bug on my end, and is fixed in https://github.com/patrick-kidger/equinox/pull/244. (Which also removes the wrapper_prims argument in favour of a registry of rules.)
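Higher-order primitives such as lax.cond and lax.scan carry their bodies as sub-jaxprs inside their params, so any pass over a jaxpr has to recurse into them. A hedged sketch (hypothetical helper names, assuming a recent JAX) that lists primitive names with and without that recursion shows the difference: a shallow walk sees only cond, never the cos inside a branch.

```python
import jax
from jax import lax


def _sub_jaxprs(params):
    """Yield any jaxprs stored in an equation's params (e.g. cond branches)."""
    for value in params.values():
        items = value if isinstance(value, (tuple, list)) else (value,)
        for item in items:
            if hasattr(item, "jaxpr") and hasattr(item, "consts"):  # ClosedJaxpr
                yield item.jaxpr
            elif hasattr(item, "eqns") and hasattr(item, "invars"):  # Jaxpr
                yield item


def primitive_names(jaxpr, recurse):
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        if recurse:
            for sub in _sub_jaxprs(eqn.params):
                names.extend(primitive_names(sub, recurse))
    return names


def f(x):
    return lax.cond(x > 0, lambda y: lax.cos(y), lambda y: lax.sin(y), x)


closed = jax.make_jaxpr(f)(1.0)
shallow = primitive_names(closed.jaxpr, recurse=False)
deep = primitive_names(closed.jaxpr, recurse=True)
print(shallow)  # contains 'cond' but not 'cos' or 'sin'
print(deep)     # also contains the branch primitives 'cos' and 'sin'
```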

Regarding your approach of modifying jax2tf.tf_impl: this is actually the expected approach for tackling this problem. jax2tf defines a new transformation (the JAX->TF exporter), and Equinox defines some new primitives. To be able to run one on the other, someone still needs to describe how this transformation should handle these new primitives, and this is exactly what you've done by adding some rules to jax2tf.tf_impl.

Indeed the "finalisation" I described above is basically just another transformation, with its own set of rules. (I went for this approach rather than defining jax2tf rules as it's a bit easier for me to maintain, as it's decoupled from the details of TensorFlow.)

With all of that said, either way we eventually bump into the issue of tf2onnx not supporting certain operations! I imagine it should be doable to add these ops to tf2onnx though; e.g. I see that ONNX has an If operator, which could probably be used to implement switch.

llandsmeer commented 1 year ago

Removing the wrapper_prims argument does make the to_onnx function quite a bit cleaner. It seems that tf2onnx is not planning to support match/case-style branching (tf.switch_case) anytime soon, so I had to convert the tf.switch_case into a tf.case. Not the cleanest solution, but it seems to produce working code :)

import numpy as np
import onnxruntime as rt

def patch_jax2tf_with_case_instead_of_switch_case():
    import tensorflow as tf
    import jax
    from jax.experimental.jax2tf import jax2tf
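    def _cond(index, *operands, branches, linear):
        del linear  # not needed for this TF lowering
        # One (predicate, branch_fn) pair per branch. i and jaxpr are bound as
        # default arguments: a bare closure in the comprehension would only see
        # the final loop values (late binding), so every branch would run the
        # last jaxpr.
        return tf.case([
              (tf.equal(i, index),
               lambda i=i, jaxpr=jaxpr: jax2tf._interpret_jaxpr(
                   jaxpr, *operands, extra_name_stack=f'branch_{i}_fun'))
              for i, jaxpr in enumerate(branches)
        ], exclusive=True)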
    jax2tf.tf_impl[jax.lax.cond_p] = _cond

def build_onnx_model():
    import equinox.internal as eqxi
    from diffrax import diffeqsolve, ODETerm, Euler
    def simulate(y0):
        solution = diffeqsolve(
                terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
                t0=0, t1=1, dt0=0.1, y0=y0
                )
        return solution.ys[0]
    onnx_generator_fn = eqxi.to_onnx(simulate)
    model, _none = onnx_generator_fn(1.0)
    return model

patch_jax2tf_with_case_instead_of_switch_case()

onnx_model = build_onnx_model()

sess = rt.InferenceSession(onnx_model.SerializeToString())
input_name = sess.get_inputs()[0].name
onnx_output = sess.run(None, {input_name: np.array(100.0).astype('float32')})[0]
print(onnx_output)

assert np.isclose(onnx_output, 100 * np.exp(-1), rtol=0.1, atol=0.1)

Again, thank you very much for all the effort you put into this project, and for the quick issue resolution!

patrick-kidger commented 1 year ago

Nice! This is cool to see -- and I'm glad you managed to get things working.