patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Is Equinox compatible with Flax? #886

Closed · jeertmans closed this 4 weeks ago

jeertmans commented 4 weeks ago

Hello!

I use Equinox everywhere in my code (especially to create classes that are PyTrees), and creating ML models is pretty simple thanks to your library.

However, there exist plenty of ML "frameworks" for JAX, and each framework has its own pros and cons. I am now facing an "issue": I want to integrate $\mathrm{E}(3)$-equivariant modules in my code, and, to the best of my knowledge, the only available libraries use Flax, e.g., E3x (the "old" linen module, not nnx).

I guess I am not the first one trying to mix multiple frameworks, so I was wondering if you have any recommendations or experience with this?

E.g., should I stick with Flax for the ML part:

class MyModule(equinox.Module):
  ...

class EquivariantModel(flax.linen.Module):
  features: int = 8
  max_degree: int = 2

  @flax.linen.compact
  def __call__(self, x: MyModule):
    ...

Or is it better to put layers as dataclass attributes, and go the usual "Equinox" way?

class Model(equinox.Module):
  layer: e3x.nn.TensorDense

I just wanted to know if you have an opinion on or experience with this.

patrick-kidger commented 4 weeks ago

Given the preferences you've expressed -- using Equinox as your main NN library -- I suspect the best thing to do will be something like this:

class Foo(flax.linen.Module): ...

class FlaxToEquinox(eqx.Module):
    params: PyTree[Array]
    flax_model: flax.linen.Module = eqx.field(static=True)

    def __init__(self, flax_model: flax.linen.Module):
        # Initialise the Flax parameters once and store them as PyTree leaves.
        self.params = flax_model.init(...)
        # The Flax module itself holds no arrays, so it can be static metadata.
        self.flax_model = flax_model

    def __call__(self, ...):
        return self.flax_model.apply(self.params, ...)

model = FlaxToEquinox(Foo(...))

and then never have to think about the underlying Flax model at all -- just do Equinox as normal.
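For concreteness, here is one way the elided pieces above could be filled in, as a minimal self-contained sketch -- the tiny Dense layer, the shapes, and the squared-error loss are all illustrative placeholders, not part of the answer itself:

import equinox as eqx
import flax.linen as nn
import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)


class FlaxToEquinox(eqx.Module):
    params: PyTree[Array]
    flax_model: nn.Module = eqx.field(static=True)

    def __init__(self, flax_model: nn.Module, key, dummy_input):
        # Flax infers parameter shapes from an example input at init time.
        self.params = flax_model.init(key, dummy_input)
        self.flax_model = flax_model

    def __call__(self, x):
        return self.flax_model.apply(self.params, x)


model = FlaxToEquinox(Foo(), jax.random.PRNGKey(0), jnp.zeros(3))

# "Equinox as normal": the wrapped Flax parameters are ordinary PyTree leaves,
# so filtered transformations differentiate through them like any other weights.
def loss(model, x):
    return jnp.sum(model(x) ** 2)

grads = eqx.filter_grad(loss)(model, jnp.ones(3))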

jeertmans commented 4 weeks ago

Thanks for your quick answer!

I ended up with something like this:

from typing import Any

import equinox as eqx
import flax.linen as nn
import jax.numpy as jnp
from beartype import beartype as typechecker  # assuming beartype as the runtime typechecker
from jaxtyping import Array, Float, PRNGKeyArray, PyTree, jaxtyped

class FlaxE3MLP(nn.Module):
    @nn.compact
    def __call__(
        self, xyz: Float[Array, "num_points 3"]
    ) -> Float[Array, "num_points num_features"]:
        ...

class E3MLP(eqx.Module):
    """Convenient Equinox wrapper to use Flax's modules."""

    params: PyTree[Array]
    flax_model: FlaxE3MLP = eqx.field(static=True)

    def __init__(self, *args: Any, key: PRNGKeyArray, **kwargs: Any) -> None:
        self.flax_model = FlaxE3MLP(*args, **kwargs)
        # TODO: see if we can do something other than providing a dummy 'xyz' input
        self.params = self.flax_model.init(key, jnp.empty((3, 3)))

    @eqx.filter_jit
    @jaxtyped(typechecker=typechecker)
    def __call__(
        self, xyz: Float[Array, "num_points 3"]
    ) -> Float[Array, "num_points num_features"]:
        return self.flax_model.apply(self.params, xyz)

And I think this works as expected.
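For reference, usage then looks something like this (illustrative; assuming the underlying e3x layers act per-point, so the parameters don't depend on num_points and the call-time shape can differ from the dummy input):

import jax

init_key, data_key = jax.random.split(jax.random.PRNGKey(0))
model = E3MLP(key=init_key)

# num_points differs from the 3 dummy points used at initialization.
xyz = jax.random.normal(data_key, (100, 3))
features = model(xyz)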

However, it seems a bit unnatural to me that I have to pass "dummy" arguments to init the Flax model (i.e., the jnp.empty((3, 3))).

This is maybe one of the reasons they developed another module, flax.nnx, where passing random keys feels more natural (to me). But are you aware of a way to work around this dummy initialization? I ask mainly because xyz can have arbitrarily many num_points, so it feels weird to pass a fixed number of points for initialization.

patrick-kidger commented 4 weeks ago

Not that I know of, I'm afraid. I think that's just a limitation of Flax.

jeertmans commented 4 weeks ago

No problem, you already answered the main question here :)

Thanks!