Filtering (conceptually) parses the pytree and removes any nodes whose types JAX doesn't play well with (e.g. functions, strings, etc.). In the CNN example, it would probably fail without the filter because of the functions in the list. The reason the RNN doesn't fail/need the filter is that it doesn't contain any elements of invalid/non-JAX-compatible types.
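As a rough sketch of what that means (the `Toy` module below is made up purely for illustration; it isn't the CNN from the example):

import equinox as eqx
import jax.numpy as jnp
from typing import Callable

class Toy(eqx.Module):
    weight: jnp.ndarray
    activation: Callable

    def __init__(self):
        self.weight = jnp.ones((2, 2))
        self.activation = jnp.tanh  # a function leaf, which optax can't handle

toy = Toy()

# eqx.filter keeps only the leaves matching the filter; everything else becomes None.
arrays_only = eqx.filter(toy, eqx.is_array)

# eqx.partition splits the module into (array leaves, everything else),
# so the two halves can be recombined later with eqx.combine.
params, static = eqx.partition(toy, eqx.is_array)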
Got it! Thank you for the answer. I may need a little bit more help here. Which function in the CNN case causes the issue?
Here is the CNN model's definition:
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x
And here is the RNN example:
class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)
I assume you mean that the self.layers list in the CNN model is the reason we need the filter?
Also, what would be the best practice/recommended route? Should I filter all the time? Or should I pass the model in as-is and only add the filter when I see errors? Are there performance differences between filtering and not filtering?
Thank you again for your help!
Optax's init does some tree parsing to create a state that matches the structure of the pytree but has different values at the leaves. If you look at the leaves of the RNN, you see only numeric types, which optax can understand. But if you look at the leaves of the CNN, you see things like <jax._src.custom_derivatives.custom_jvp at 0x7855dd5bc880>, which then cause an error in optax: TypeError: zeros_like requires ndarray or scalar arguments, got <class 'function'> at position 0. In the CNN, things like relu are registered in the pytree, and those are the functions causing the issue. If you only called them inside the __call__ method, they wouldn't be part of the pytree and thus wouldn't cause an error in optax.
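You can check this yourself along these lines (a sketch: assume `cnn` and `rnn` are instances of the two models quoted above, and the optimizer choice is arbitrary):

import jax
import optax
import equinox as eqx

print(jax.tree_util.tree_leaves(rnn))  # arrays and numbers only
print(jax.tree_util.tree_leaves(cnn))  # also contains function objects like jax.nn.relu

optim = optax.adam(1e-3)
optim.init(rnn)                            # fine: every leaf is array-like
# optim.init(cnn)                          # raises the TypeError above
optim.init(eqx.filter(cnn, eqx.is_array))  # fine: function leaves are filtered out first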
I usually just filter for optax, because I'm certain I don't want optax to see anything that isn't an array (for other functions I might think about it more). The performance here (in most cases) isn't something I even think about, since this is a one-time step that takes a tiny percentage of the runtime (the training loop takes much more time). In general, filtering/partitioning does have some overhead (that was discussed some here: https://github.com/patrick-kidger/equinox/issues/824).
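For reference, the pattern I usually default to looks roughly like this (a sketch, assuming `model` is any eqx.Module and `loss_fn(model, batch)` is whatever loss you're using):

import optax
import equinox as eqx

optim = optax.adam(1e-3)
# Only the array leaves go into the optimizer state.
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, opt_state, batch):
    # filter_grad differentiates with respect to the array leaves only.
    grads = eqx.filter_grad(loss_fn)(model, batch)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state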
Thank you for the comprehensive answer! Very helpful. Going to close this issue for now. Thanks again for your time!
Hi. Thanks for the awesome package. It does make the transition from pytorch to jax much simpler.
I was following this RNN example here and realized that in cell no. 4, the initialization of the optimization state is defined as:
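(roughly this, if I recall the notebook correctly, with `optim` being the optax optimizer:)

opt_state = optim.init(model)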
However, in the CNN example, the optimization state is defined as:
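(roughly, with the model filtered down to its array leaves first:)

opt_state = optim.init(eqx.filter(model, eqx.is_array))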
I feel like we should do the same filtering in the RNN too? Surprisingly, I don't get any error when not filtering in the RNN, which is itself a question I don't know the answer to. :)
Thank you for your time!