Closed NeilGirdhar closed 1 month ago
It seems like in a variety of places, Equinox is torn between making modules pytrees (with all static fields marked as static) versus leaving some static fields as dynamic to allow using tree_at, etc. to modify those fields.
This seems confusing to me (at least for now). Would it be possible to add a flag to tree_at so that it crawls over everything it can, and modify the various modules so that they mark their static fields correctly?
I believe StateIndex is a pytree. It just subclasses eqx.Module in the usual way.
Looking at your PR it seems like you're trying to avoid having non-arrays in its PyTree structure. I can see that that's a small QoL improvement, which seems reasonable to me.
On your latter point: note that it's not possible to have a module declare its static fields at class definition time. Consider for example eqx.nn.MLP.activation, which may be a 'static' jax.nn.relu or a 'dynamic' eqx.nn.PReLU.
WDYT?
I believe StateIndex is a pytree. It just subclasses eqx.Module in the usual way.
Hmm, I think you're mistaken. The init field is dynamic, but can be set to object(), which is not a pytree. Therefore, StateIndex is not a (proper) pytree. That's why the above program crashes.
you're trying to avoid having non-arrays in its PyTree structure.
Not exactly. I'm avoiding having non-pytrees in dynamic parameters. Another alternative would be to use a sentinel that's a pytree. For example,
import equinox as eqx

class Sentinel(eqx.Module):
    pass

sentinel = Sentinel()  # Use this instead of sentinel = object()
# When checking, you can do isinstance(init, Sentinel) instead of init is sentinel
On your latter point: note that it's not possible to have a module declare its static fields at class definition time. Consider for example eqx.nn.MLP.activation, which may be a 'static' jax.nn.relu or a 'dynamic' eqx.nn.PReLU.
A user of eqx.nn.MLP can just wrap their static activation value (jax.nn.relu) in a Module like this:
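from collections.abc import Callable
import equinox as eqx
import jax
from equinox.nn import MLP
from jax import Array, jit

class JaxActivation(eqx.Module):
    f: Callable[[Array], Array] = eqx.field(static=True)
    def __call__(self, x: Array, /) -> Array:
        return self.f(x)

@jit
def f(m: MLP, /): ...

mlp = MLP(3, 2, activation=JaxActivation(jax.nn.relu))  # remaining MLP arguments omitted
f(mlp)  # Won't crash!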
This will ensure that MLP remains a proper pytree. Equinox's decision to make the field dynamic is correct, in my opinion.
All types are pytrees. JAX is explicit about the fact that even object() is a valid pytree leaf -- see the first code block here: https://jax.readthedocs.io/en/latest/working-with-pytrees.html The only distinction between types is whether they are pytree nodes or pytree leaves.
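As a quick illustration of the leaf-versus-node distinction, the following sketch flattens a dict containing a bare object(). The dict and its keys are made up for the example.

```python
import jax

sentinel = object()
tree = {"count": 0, "init": sentinel}

# The dict is a pytree node; its values, including object(), are leaves.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Unflattening with the same leaves rebuilds an equal dict.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Flattening succeeds because any type that is not registered as a node is simply treated as a leaf.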
Anyway, to ease use of jax.jit (which accepts specifically PyTree[ArrayLike] and not PyTree[Any]), eqx.nn does generally try to mark known-static fields as static. I'll continue discussion on your PR over there; as I say I think this is a reasonable QoL improvement.
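The PyTree[ArrayLike] restriction can be seen with a small pure-JAX sketch. The function and dict names here are illustrative only, and the example assumes jax.jit reports a non-array leaf as a TypeError.

```python
import jax
import jax.numpy as jnp


def add_one(tree):
    return tree["x"] + 1


array_only = {"x": jnp.ones(2)}
with_object = {"x": jnp.ones(2), "init": object()}

# Every leaf is an array, so jit accepts the argument.
ok = jax.jit(add_one)(array_only)

# A non-array leaf is a valid pytree leaf, but jit cannot trace it.
try:
    jax.jit(add_one)(with_object)
    rejected = False
except TypeError:
    rejected = True
```

This is why marking known-static fields as static helps: they drop out of the leaves that jax.jit has to trace.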
All types are pytrees. JAX is explicit about the fact that even object() is a valid pytree leaf
Ah, okay! I'm using the wrong terminology. Let's say "dynamic" then? That is, things that can be passed dynamically to a function decorated by jax.jit. In Flax, they used to call this pytree_like, which is where I got the terminology.
I say I think this is a reasonable QoL improvement.
Great! Thanks.
Is this a correct usage of stateful programming? If so, why isn't StateIndex a PyTree?