Open aeftimia opened 1 month ago
This should definitely be possible; we'd just have to write an API that wraps the existing jax.export.{export, deserialize}.
Under the hood, eqx.filter_jit is basically just wrapping jax.jit with nicer behaviour. We could arrange to unwrap that, run jax.export.export, and then save any additional metadata. Conversely, an analogous filter_deserialize function could then read the result and package things back up.
I'd be happy to take a PR on this.
When I try
import equinox as eqx
from jax import export

class MyClass(eqx.Module):
    @eqx.filter_jit
    def fn(self, ...):
        ...

model = MyClass(...)
exported = export.export(model.fn._cached)
exported(*args, **kwargs).serialize()
I get this error.
Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jaxlib.xla_extension.ArrayImpl'> for function simulate is non-hashable.
I'm guessing all the extra stuff that filter_jit turns into static args wouldn't be supported for serialization? Do you think there is any way around that?
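For reference, this kind of error can be reproduced in plain JAX, without Equinox, whenever an array ends up in a position that jax.jit treats as static. The sketch below is only an illustration: the function f and the choice of static_argnums=2 are made up here, and the exact exception type and message may differ between JAX versions.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Argument index 2 is declared static, so it must be hashable.
@partial(jax.jit, static_argnums=2)
def f(a, b, flag):
    return a + b


# A hashable static argument (here a string) is fine.
ok = f(jnp.array(1.0), jnp.array(2.0), "any hashable value")

# A JAX array is not hashable, so passing one in the static slot fails.
try:
    f(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0))
    raised = False
except (ValueError, TypeError) as err:
    raised = True
    message = str(err)
```

So the message does not necessarily mean that static arguments cannot be exported; it may simply mean that an array was passed where the jitted function expects a static value.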
Just following up with the minimal example to replicate the problem.
import equinox as eqx
import jax.numpy as jnp
from jax import export

class Test(eqx.Module):
    @eqx.filter_jit
    def fn(self, data):
        return data

obj = Test()
fn = obj.fn
x = jnp.array(3.0)
fn(x)
exported = export.export(fn.func._cached)
exported(*fn.args, obj, x, **fn.keywords).serialize()
Right, so this is because the internal JIT'd function (fn.func._cached) does not have the same signature as the original function.
This is kind of the whole point of filter_jit: it automatically looks at your arguments, then splits them into three groups of 'arguments that should be traced and donated', 'arguments that should be traced and not donated', and 'arguments that are static', and then passes them across the JIT boundary in those groups.
This is the function that's actually JIT'd:
and here is where they are split up in this way:
JAX allows you to serialize and deserialize a jitted function as described here.
https://jax.readthedocs.io/en/latest/_autosummary/jax.export.export.html#jax.export.export
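For context, a minimal round trip with a plain jax.jit function looks roughly like this. The function double is just a placeholder, and the sketch assumes a JAX version that provides jax.export along with whatever dependencies its serialization needs.

```python
import jax
import jax.numpy as jnp
from jax import export


@jax.jit
def double(x):
    return 2.0 * x


x = jnp.array(3.0)

# Trace and lower the jitted function for inputs shaped like x.
exported = export.export(double)(x)

# Serialize to bytes, e.g. to write to disk.
blob = exported.serialize()

# Later (possibly in another process): load and call it again.
restored = export.deserialize(blob)
y = restored.call(x)
```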
I tried this for a filter_jitted function, but received this error.
Is it possible to replicate serialization like this for filter_jit? I wasn't sure if that was even theoretically possible, but figured it was at least worth asking.