patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

"Full" `eqx.Module` serialisation #499

Open pfackeldey opened 11 months ago

pfackeldey commented 11 months ago

Dear @patrick-kidger,

this project is awesome, thank you very much for it!!

I am currently developing an equinox-based fitting library for binned likelihood fits (dilax) in high energy physics. These types of fits are typically performed in physics analyses at the multi-purpose experiments (ATLAS & CMS) at CERN.

Here it would be highly useful to serialise a full eqx.Module, including bound methods. As far as I can see from the equinox docs, serialisation is supported for the leaves of an eqx.Module. Is there a way to serialise a "full" eqx.Module, including not only the leaves but also all bound methods, similar to: https://www.tensorflow.org/tutorials/keras/save_and_load#save_the_entire_model ?

Maybe going even one step further: it would also be helpful to know whether there is any way to access, potentially modify, and serialise an arbitrary JAX computation graph (similar to a Dask graph)?

Thank you very much for your help in advance!

Best, Peter

patrick-kidger commented 11 months ago

Saving just the model's leaves should be enough to also capture all the information you need for the bound methods. Once you have the leaves then you have self, and that's the only subnode of a bound method. So all the information needed to reconstruct a bound method is also present.
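As a small illustration of this point (a toy module, not from the thread): in equinox a bound method of an eqx.Module is itself a pytree, and its only leaves are those of `self`.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Affine(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def apply(self, x):
        return self.weight * x + self.bias

module = Affine(weight=jnp.array(2.0), bias=jnp.array(1.0))

# The module's leaves...
print(jax.tree_util.tree_leaves(module))        # [Array(2., ...), Array(1., ...)]
# ...are exactly the leaves of the bound method, carried via `self`.
print(jax.tree_util.tree_leaves(module.apply))  # the same two arrays
```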

In terms of accessing or modifying -- JAX has an internal representation called "jaxprs". You can access this via jax.make_jaxpr or equinox.filter_make_jaxpr, and at that point you are free to modify it in any way you like. It is common to do so with a "jaxpr interpreter"; as an example, here is a fairly simple one you can pattern-match from. You can then turn the modified version back into a function using jax.core.jaxpr_as_fun (although note that this is advanced enough that this function isn't technically public API).
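For concreteness, a minimal sketch (not from the thread) of tracing a function into a jaxpr and turning it back into a callable; `jax.core.jaxpr_as_fun` is the not-quite-public function mentioned above:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) + y ** 2

# Trace `f` into its intermediate representation (a ClosedJaxpr).
closed_jaxpr = jax.make_jaxpr(f)(jnp.array(1.0), jnp.array(2.0))
print(closed_jaxpr)  # equations, input variables, output variables

# A jaxpr interpreter would walk `closed_jaxpr.jaxpr.eqns` here and rewrite them.

# Turn the (possibly modified) jaxpr back into a callable.
g = jax.core.jaxpr_as_fun(closed_jaxpr)
print(g(jnp.array(1.0), jnp.array(2.0)))  # a list with one entry, equal to f(1.0, 2.0)
```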

As for serialisation. You can arrange to serialise a JAX computation graph as HLO, TF, or ONNX. Which would be most useful for you? (There's no JAX-native format at the moment -- this has been an outstanding request for a while.)
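As a rough sketch of what exporting to HLO looks like in practice (an illustration, not something prescribed in the thread), the StableHLO text of a jitted computation can be obtained by lowering it:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

lowered = jax.jit(f).lower(jnp.array([1.0]), jnp.array([2.0]))
print(lowered.as_text())  # StableHLO text for the traced computation

compiled = lowered.compile()  # compile and execute within the same process
print(compiled(jnp.array([1.0]), jnp.array([2.0])))
```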

pfackeldey commented 11 months ago

Thank you very much for your reply!

Just to make sure I understand you correctly: I would like to serialise the following eqx.Module and load it again without needing the code of the eqx.Module to be available anywhere, e.g. (pseudo-code):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Foo(eqx.Module):
    a: jax.Array
    b: jax.Array

    def __init__(self, a: jax.Array, b: jax.Array) -> None:
        self.a = a
        self.b = b

    @eqx.filter_jit
    def bar(self) -> jax.Array:
        return self.a + self.b

foo = Foo(a=jnp.array([1.0]), b=jnp.array([2.0]))
foo.bar()
# >> Array([3.], dtype=float32)

# some method to serialise
save(foo, "foo.eqx")

del foo

# load it again, without the need to have the code of `Foo`
foo = load("foo.eqx")
foo.bar()
# >> Array([3.], dtype=float32)
```

Is that possible with equinox's serialisation? At the end of the day I'm looking for these `save` and `load` methods somewhere. Sorry if there is an obvious solution that I am currently not seeing...

Thank you so much for the pointer to the code snippet modifying a jaxpr, that is helpful!

I think I would only need serialisation of the example pseudo-code I've written above. I've already generated the graph of the `.bar` function defined above as HLO and saved it as a text file (I assume you were referring to this, i.e. `foo.bar().lower().as_text()`). However, it is not clear to me how I could run that graph again purely from the text file.

Thank you for your help! Best, Peter

patrick-kidger commented 11 months ago

So the code for Foo will certainly have to exist somewhere, since that is what is eventually returned by your load function.

But if you'd like to encapsulate it, you certainly can. Take a look at this example. (I'll add a link to this from the serialisation page, it seems this could be better-advertised!)
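For reference, a minimal sketch of that pattern, using the hypothetical `save`/`load` names from the pseudo-code above and assuming the `Foo` class defined earlier is importable at load time:

```python
import jax.numpy as jnp
import equinox as eqx
# assumes `Foo` (as defined in the earlier snippet) is importable here

def save(model, filename):
    eqx.tree_serialise_leaves(filename, model)

def load(filename):
    # Build a "skeleton" Foo with the right structure; its array values are
    # then overwritten by the deserialised leaves.
    like = Foo(a=jnp.zeros(1), b=jnp.zeros(1))
    return eqx.tree_deserialise_leaves(filename, like)

save(Foo(a=jnp.array([1.0]), b=jnp.array([2.0])), "foo.eqx")
foo = load("foo.eqx")
foo.bar()  # Array([3.], dtype=float32)
```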

pfackeldey commented 11 months ago

Thank you very much! Alright, that makes sense to me :)

Just as a side note: I found out that tf.keras.Model.save serialises all functions/classes/etc. to some well-defined dict structures, which can then be deserialised back into actual instances of these objects using only these dict structures and tf.keras [1]. That way models can be shared between people without them needing the code of a particular Model; they only need to have tf.keras installed.

Best, Peter

[1] https://github.com/keras-team/keras/blob/v2.13.1/keras/saving/serialization_lib.py#L401