patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

StateIndex is a Module, but not a PyTree #842

Closed: NeilGirdhar closed this issue 1 month ago

NeilGirdhar commented 2 months ago

Is this a correct usage of stateful programming?

import equinox as eqx
import jax.numpy as jnp
from jax import jit

class Model(eqx.Module):
    x: eqx.nn.StateIndex = eqx.field(init=False)

    def __post_init__(self) -> None:
        self.x = eqx.nn.StateIndex(jnp.zeros(()))

@jit
def f(m: Model, state: eqx.nn.State) -> None:
    pass

model, state = eqx.nn.make_with_state(Model)()
f(model, state)

If so, why isn't StateIndex a PyTree?

NeilGirdhar commented 2 months ago

It seems like, in a variety of places, Equinox is torn between making modules proper pytrees (with all static fields marked as static) and leaving some static fields dynamic so that tree_at, etc. can modify those fields.

This seems confusing to me (at least for now). Would it be possible to add a flag to tree_at so that it crawls over everything it can, and to modify the various modules so that they mark their static fields correctly?

patrick-kidger commented 2 months ago

I believe StateIndex is a pytree. It just subclasses eqx.Module in the usual way.

Looking at your PR, it seems like you're trying to avoid having non-arrays in its PyTree structure. I can see that that's a small QoL improvement, which seems reasonable to me.

On your latter point: note that it's not possible to have a module declare its static fields at class definition time. Consider for example eqx.nn.MLP.activation, which may be a 'static' jax.nn.relu or a 'dynamic' eqx.nn.PReLU.

WDYT?

NeilGirdhar commented 2 months ago

I believe StateIndex is a pytree. It just subclasses eqx.Module in the usual way.

Hmm, I think you're mistaken. The init field is dynamic, but it can be set to object(), which is not a pytree. Therefore, StateIndex is not a (proper) pytree. That's why the above program crashes.

you're trying to avoid having non-arrays in its PyTree structure.

Not exactly. I'm avoiding having non-pytrees in dynamic parameters. Another alternative would be to use a sentinel that's a pytree. For example:

import equinox as eqx

class Sentinel(eqx.Module):
    pass

sentinel = Sentinel()  # Use this instead of sentinel = object()
# When checking, use isinstance(init, Sentinel) instead of init is sentinel

On your latter point: note that it's not possible to have a module declare its static fields at class definition time. Consider for example eqx.nn.MLP.activation, which may be a 'static' jax.nn.relu or a 'dynamic' eqx.nn.PReLU.

A user of eqx.nn.MLP can just wrap their static activation (such as jax.nn.relu) in a Module like this:

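import jax
import equinox as eqx
from typing import Callable
from jax import Array, jit
from equinox.nn import MLP

class JaxActivation(eqx.Module):
    f: Callable[[Array], Array] = eqx.field(static=True)

    def __call__(self, x: Array, /) -> Array:
        return self.f(x)

@jit
def f(m: MLP, /): ...

# MLP's default final_activation is also a plain function, so wrap it too.
# width_size, depth and key are required arguments; these values are arbitrary.
mlp = MLP(3, 2, width_size=8, depth=2,
          activation=JaxActivation(jax.nn.relu),
          final_activation=JaxActivation(lambda x: x),
          key=jax.random.PRNGKey(0))
f(mlp)  # Won't crash!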

This will ensure that MLP remains a proper pytree. Equinox's decision to make the field dynamic is correct, in my opinion.

patrick-kidger commented 2 months ago

All types are pytrees. JAX is explicit about the fact that even object() is a valid pytree leaf -- see the first code block here: https://jax.readthedocs.io/en/latest/working-with-pytrees.html. The only distinction between types is whether they are pytree nodes or pytree leaves.
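Here is a quick check of that point, as a minimal sketch. `jax.tree_util` treats a bare `object()` as an ordinary leaf, both inside a container and on its own. It is only transformations such as `jax.jit` that additionally require their leaves to be valid JAX types.

```python
import jax

leaf = object()
tree = {"a": leaf, "b": [1.0, 2.0]}

# Dict keys are visited in sorted order: leaf, 1.0, 2.0.
leaves = jax.tree_util.tree_leaves(tree)

# On its own, object() is a single leaf rather than a container node.
n_leaf = jax.tree_util.tree_structure(leaf).num_leaves
```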

Anyway, to ease use of jax.jit -- which accepts specifically PyTree[ArrayLike] and not PyTree[Any] -- eqx.nn does generally try to mark known-static fields as static. I'll continue the discussion over on your PR; as I say, I think this is a reasonable QoL improvement.

NeilGirdhar commented 2 months ago

All types are pytrees. JAX is explicit about the fact that even object() is a valid pytree leaf

Ah, okay! I'm using the wrong terminology. Let's say "dynamic" then? That is, things that can be passed dynamically to a function decorated by jax.jit. In Flax, they used to call this pytree_like, which is where I got the terminology.

I say I think this is a reasonable QoL improvement.

Great! Thanks.