patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

A potential typo? Or should we do `eqx.filter(model, eqx.is_array)` here in the RNN example? #887

Closed DigitalPig closed 3 weeks ago

DigitalPig commented 3 weeks ago

Hi. Thanks for the awesome package. It does make the transition from pytorch to jax much simpler.

I was following this RNN example here and realized that in cell no. 4, the optimizer state is initialized as:

opt_state = optim.init(model)

However, in the CNN example, it is initialized as:

opt_state = optim.init(eqx.filter(model, eqx.is_array))

I feel like we should apply the same filter in the RNN too? Surprisingly, I don't get any error when not filtering in the RNN, though, which is itself something I don't understand. :)

Thank you for your time!

lockwo commented 3 weeks ago

Filtering (conceptually) parses the pytree and removes any nodes whose types JAX doesn't play well with (e.g. functions, strings, etc.). In the CNN example, it would probably fail without the filter because of the functions in the list. The reason the RNN doesn't fail/doesn't need the filter is that none of its leaves are invalid/non-JAX-compatible types.
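
As a rough sketch of what that looks like (the Toy module below is just made up for illustration):

import equinox as eqx
import jax
import jax.numpy as jnp
from typing import Callable

class Toy(eqx.Module):
    weight: jax.Array
    activation: Callable

    def __init__(self):
        self.weight = jnp.ones((2, 2))
        self.activation = jax.nn.relu  # a function leaf, not an array

toy = Toy()

# eqx.filter keeps the leaves matching the predicate (arrays here) and
# replaces everything else (the relu function) with None.
arrays_only = eqx.filter(toy, eqx.is_array)

# eqx.partition splits the pytree into the matching and non-matching halves,
# which eqx.combine can stitch back together later.
params, static = eqx.partition(toy, eqx.is_array)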

DigitalPig commented 3 weeks ago

Got it! Thank you for the answer. I may need a little bit more help here: which function in the CNN case causes the issue?

Here is the CNN model's definition:

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

And here is the RNN example:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import lax

class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)

I assume you mean the self.layers list in the CNN model is the reason we need the filter?

Also, what would be the best practice/recommended route? Should I always filter? Or should I throw the model in and only filter when I see errors? Are there performance differences between filtering and not filtering?

Thank you again for your help!

lockwo commented 3 weeks ago

Optax's init does some tree parsing to create a state that matches the structure of the pytree but has different values at the leaves. If you look at the leaves of the RNN, you see only numeric types, which optax can understand. But if you look at the leaves of the CNN, you see things like <jax._src.custom_derivatives.custom_jvp at 0x7855dd5bc880>, which then cause an error in optax: TypeError: zeros_like requires ndarray or scalar arguments, got <class 'function'> at position 0. In the CNN, things like relu are registered in the pytree via self.layers, and those functions are the issue. If you just called them inside __call__ instead, they wouldn't be part of the pytree and thus wouldn't trip up optax.
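
If you want to see this yourself, something like the following (assuming the CNN class above plus jax, eqx, and optax are in scope; the learning rate is just an example value) prints the function leaves and shows how filtering avoids the error:

import optax

model = CNN(jax.random.PRNGKey(0))

# The leaves include the activation functions registered via self.layers,
# e.g. the custom_jvp object backing jax.nn.relu.
for leaf in jax.tree_util.tree_leaves(model):
    print(type(leaf))

optim = optax.adamw(3e-4)

# optim.init(model) would raise the zeros_like TypeError on the function
# leaves; filtering down to the array leaves first avoids it.
opt_state = optim.init(eqx.filter(model, eqx.is_array))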

I usually just filter for optax, because I'm certain I don't want optax to see anything that isn't an array (for other functions/transformations I might think about it more). The performance here (in most cases) isn't something I even worry about, since this is a one-time step that is a tiny percentage of runtime (the training loop takes much more time). In general, filtering/partitioning does have an overhead (that was discussed some here: https://github.com/patrick-kidger/equinox/issues/824).
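
For completeness, the filtered training-step pattern from the examples looks roughly like this (a sketch only; loss_fn, x, y, and the optimizer settings are placeholders):

optim = optax.adamw(3e-4)  # placeholder optimizer
opt_state = optim.init(eqx.filter(model, eqx.is_array))

def loss_fn(model, x, y):
    # placeholder loss: batched mean squared error
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

@eqx.filter_jit
def train_step(model, opt_state, x, y):
    # Gradients are only taken with respect to the array leaves of the model.
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss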

DigitalPig commented 3 weeks ago

Thank you for the comprehensive answer! Very helpful. Going to close this issue for now. Thanks again for your time!