patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Replicating jax's export feature to cache filter_jitted functions #879

Open aeftimia opened 1 month ago

aeftimia commented 1 month ago

JAX lets you serialize and deserialize a jitted function, as described here:

https://jax.readthedocs.io/en/latest/_autosummary/jax.export.export.html#jax.export.export
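For reference, a minimal round trip with a plain `jax.jit` function looks roughly like this. It is a sketch that assumes JAX 0.4.30 or later, where `jax.export` is public; `simulate` is only an illustrative stand-in for the real function:

```python
import jax
import jax.numpy as jnp
from jax import export


def simulate(x):
    # Illustrative stand-in for the user's function.
    return jnp.sin(x) * 2.0


# Export for a float32 input of shape (3,).
exported = export.export(jax.jit(simulate))(
    jax.ShapeDtypeStruct((3,), jnp.float32)
)
blob = exported.serialize()  # bytes-like; can be written to disk

# Later, possibly in another process: rebuild and call the exported function.
rehydrated = export.deserialize(blob)
result = rehydrated.call(jnp.arange(3.0, dtype=jnp.float32))
```

Note that `jax.export.export` expects the result of `jax.jit`, which is why passing an `eqx.filter_jit` wrapper fails.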

I tried this with a filter_jit-ted function, but received this error:

```
Function to be exported must be the result of `jit` but is: _JitWrapper(
  fn='simulate',
  filter_warning=False,
  donate_first=False,
  donate_rest=False
)
```

Is it possible to replicate serialization like this for filter_jit? I wasn't sure if that was even theoretically possible, but figured it was at least worth asking.

patrick-kidger commented 1 month ago

This should definitely be possible; we'd just have to write an API that wraps the existing jax.export.{export, deserialize}.

Under the hood, eqx.filter_jit is basically just a wrapper around jax.jit with nicer behaviour. We could arrange to unwrap that, run jax.export.export, and then save any additional metadata. Conversely, an analogous filter_deserialize function could then read the result and package things back up.

I'd be happy to take a PR on this.

aeftimia commented 1 month ago

When I try

```python
import equinox as eqx
from jax import export

class MyClass(eqx.Module):
    @eqx.filter_jit
    def fn(self, ...):
        ...

model = MyClass(...)
exported = export.export(model.fn._cached)
exported(*args, **kwargs).serialize()
```

I get this error.

```
Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jaxlib.xla_extension.ArrayImpl'> for function simulate is non-hashable.
```

I'm guessing all the extra stuff that filter_jit turns into static args wouldn't be supported for serialization? Do you think there is any way around that?
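That error comes from `jax.jit` itself: values in static positions are hashed into the compilation cache key, so they must be hashable, and JAX arrays are not. A small illustration with plain `jax.jit` (`scale` and `try_static` are illustrative names; the sketch catches both `ValueError` and `TypeError` rather than relying on the exact exception type):

```python
import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# `factor` (index 1) is static: its value becomes part of jax.jit's
# cache key, so it has to be hashable.
scale_jit = jax.jit(scale, static_argnums=1)


def try_static(factor):
    """Return True if `factor` is accepted as a static argument."""
    try:
        scale_jit(jnp.ones(2), factor)
        return True
    except (ValueError, TypeError):
        # Non-hashable static arguments, such as arrays, are rejected.
        return False
```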

aeftimia commented 1 month ago

Just following up with a minimal example that reproduces the problem:

```python
import equinox as eqx
import jax.numpy as jnp
from jax import export

class Test(eqx.Module):

    @eqx.filter_jit
    def fn(self, data):
        return data

obj = Test()
fn = obj.fn
x = jnp.array(3.0)
fn(x)
exported = export.export(fn.func._cached)
exported(*fn.args, obj, x, **fn.keywords).serialize()
```

patrick-kidger commented 1 month ago

Right, so this is because the internal JIT'd function (fn.func._cached) does not have the same signature as the original function.

This is kind of the whole point of filter_jit: it automatically inspects your arguments, splits them into three groups ('arguments that should be traced and donated', 'arguments that should be traced and not donated', and 'arguments that are static'), and then passes them across the JIT boundary in those groups.

This is the function that's actually JIT'd:

https://github.com/patrick-kidger/equinox/blob/d9b3ffde8903171263fa9fe303f2fdcb242f3bdc/equinox/_jit.py#L43

and here is where they are split up in this way:

https://github.com/patrick-kidger/equinox/blob/d9b3ffde8903171263fa9fe303f2fdcb242f3bdc/equinox/_jit.py#L220-L222