jeertmans closed this issue 4 weeks ago.
Given the preferences you've expressed -- using Equinox as your main NN library -- then I suspect the best thing to do will be something like this:
import equinox as eqx
import flax.linen
from jaxtyping import Array, PyTree

class Foo(flax.linen.Module): ...

class FlaxToEquinox(eqx.Module):
    params: PyTree[Array]
    flax_model: flax.linen.Module = eqx.field(static=True)

    def __init__(self, flax_model: flax.linen.Module):
        self.params = flax_model.init(...)  # Flax needs a PRNG key and example inputs here
        self.flax_model = flax_model

    def __call__(self, *args, **kwargs):
        return self.flax_model.apply(self.params, *args, **kwargs)

model = FlaxToEquinox(Foo(...))
and then never have to think about the underlying Flax model at all -- just do Equinox as normal.
Thanks for your quick answer!
I ended up with something like this:
from typing import Any

import equinox as eqx
import flax.linen as nn
import jax.numpy as jnp
from beartype import beartype as typechecker  # assumed runtime type checker
from jaxtyping import Array, Float, PRNGKeyArray, PyTree, jaxtyped

class FlaxE3MLP(nn.Module):
    @nn.compact
    def __call__(
        self, xyz: Float[Array, "num_points 3"]
    ) -> Float[Array, "num_points num_features"]:
        ...

class E3MLP(eqx.Module):
    """Convenient Equinox wrapper to use Flax's modules."""

    params: PyTree[Array]
    flax_model: FlaxE3MLP = eqx.field(static=True)

    def __init__(self, *args: Any, key: PRNGKeyArray, **kwargs: Any) -> None:
        self.flax_model = FlaxE3MLP(*args, **kwargs)
        # TODO: see if we can do something other than providing a dummy 'xyz' input
        self.params = self.flax_model.init(key, jnp.empty((3, 3)))

    @eqx.filter_jit
    @jaxtyped(typechecker=typechecker)
    def __call__(
        self, xyz: Float[Array, "num_points 3"]
    ) -> Float[Array, "num_points num_features"]:
        return self.flax_model.apply(self.params, xyz)
And I think this works as expected.
However, it feels a bit unnatural to me that I have to pass "dummy" arguments to initialize the Flax model (i.e., the jnp.empty((3, 3)) above).
This may be one of the reasons they developed another module, flax.nnx, where passing random keys feels more natural (to me). Are you aware of a way to work around this dummy initialization? I ask mainly because xyz can have an arbitrary number of points (num_points), so it feels odd to pass a fixed number of points for initialization.
Not that I know of, I'm afraid. I think that's just a limitation of Flax.
No issue, you already answered the main question here :)
Thanks!
Hello!
I use Equinox everywhere in my code (especially to create classes that are PyTrees), and creating ML models is pretty simple thanks to your library.
However, there exist plenty of ML "frameworks" for JAX, each with its own pros and cons. I am now facing an "issue": I want to integrate $\mathrm{E}(3)$-equivariant modules into my code, and, to the best of my knowledge, the only libraries available use Flax, e.g., E3x (the "old" linen module, not nnx). I guess I am not the first one trying to mix multiple frameworks, and I was wondering if you had any recommendations or knowledge about this?
E.g., should I stick with Flax for the ML part:
Or is it better to store the layers as dataclass attributes and go the usual "Equinox" way?
I just wanted to know whether you have an opinion on, or experience with, this question.