Open Artur-Galstyan opened 10 months ago
I've actually encountered something very similar before, in which I needed the first MVP (the second wouldn't work). My solution was to make a sort of filter_scan, and this worked for me. Here is an example using your code of what I am talking about:
import jax
import jax.numpy as jnp
import equinox as eqx


class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)


key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)


def rollout(mlp, xs):
    # Split the model into its (inexact) array leaves and everything else.
    arr, static = eqx.partition(mlp, eqx.is_inexact_array)

    def step(carry, x):
        mlp = eqx.combine(carry, static)  # just for understanding
        val = mlp(x)
        # Only the array part is handed back to scan as the carry.
        carry, _ = eqx.partition(mlp, eqx.is_inexact_array)
        return carry, [val]

    _, scan_out = jax.lax.scan(
        step,
        arr,
        xs
    )
    return scan_out


key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
print(vals)
and this does execute.
Right. See #446, which was another example of this issue.
Note that this is completely unrelated to the use of eqx.Module, though. (Modules are just PyTrees like any other!) The issue here is that your PyTree has a non-array in it (the activation function of your MLP), and jax.lax.scan requires that its carry be specifically a PyTree-of-arrays, not anything else.
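To make that concrete, here is a minimal sketch in plain JAX with no Equinox involved. The dict `tree`, the function `act`, and the step functions are hypothetical names made up for illustration. It assumes that a Python function sitting in the carry is rejected with a `TypeError`, and shows the usual workaround of carrying only the array and closing over the function.

```python
import jax
import jax.numpy as jnp


def act(x):
    # Hypothetical activation: an ordinary Python function, i.e. a non-array leaf.
    return jnp.tanh(x)


# A plain dict PyTree: one array leaf ("w") and one non-array leaf ("act").
tree = {"w": jnp.ones(3), "act": act}
xs = jnp.ones((4, 3))


def bad_step(carry, x):
    # Tries to carry the whole dict, function included.
    return carry, carry["act"](carry["w"] * x)


try:
    jax.lax.scan(bad_step, tree, xs)
    scan_failed = False
except TypeError:
    # The function leaf is not a valid JAX type, so it cannot be part of the carry.
    scan_failed = True


def good_step(w, x):
    # Carry only the array; the function is closed over instead.
    return w, act(w * x)


_, ys = jax.lax.scan(good_step, tree["w"], xs)
print(scan_failed, ys.shape)
```

The same thing happens with an eqx.Module that has a function among its leaves, which is why partitioning before the scan helps.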
Ah I see, that makes perfect sense. Thanks!
@patrick-kidger Correct me if I'm wrong, but shouldn't the non-array be marked as static (with respect to the Pytree structure), and work correctly with jax.lax.scan automatically?
If the non-array happens to be marked as static, then yes! But there's no requirement that all non-arrays be static fields, and it can happen that non-arrays are in the leaves of the PyTree as well.
In this case, the activation function of eqx.nn.MLP is not static, so this issue arises. (Why is it not static? Because one might have learnt activation functions, e.g. eqx.nn.PReLU. We can only mark a field as static if we're certain it will never have any arrays in.)
But if it is learnt, it holds parameters -- and if it holds parameters, it should be an eqx.Module, no?
I'm saying, if one is careful, and eqx.Module is used recursively throughout a code base -- this bug should never occur.
I think the problem here is that the activation function is not an eqx.Module. And, I think, if we're going to be extreme (and I think it's worth it to avoid situations like this one) -- all activation functions should be lifted to Pytree registration, even if they are just eqx.Module with a single field marked static ("just a function").
> But if it is learnt, it holds parameters -- and if it holds parameters, it should be an eqx.Module, no?
Nope. Modules are just PyTrees like tuples/lists/etc, they have no special behaviour beyond that. (The fact they are also custom classes is what means we can put methods on them, i.e. for a forward pass. That's unrelated.)
Sure, I know how Pytrees work, and I know how abstract base classes that register subclasses as Pytrees work. The statement I was making was just that, if something has parameters that are learned — it should probably be implemented as an eqx.Module.
I don’t think that answers my assertion: why not explicitly make any activation callable which is exported by eqx a Module by default?
Yes, someone can still run afoul of this issue if they use their own callable in an architecture that they make — but it seemingly removes the issue for all equinox activation functions, no?
So the only activation function provided by Equinox is eqx.nn.PReLU, and this is indeed a Module. I think we already meet the requirement you're asking for, or do I misunderstand you?
No that’s right!
But presumably it’s possible that layers use other ones from e.g. jax.nn — so I just meant exposing wrappers that explicitly mark “the code” (the function) as static in a Module, and expose the Module as a public symbol from eqx (and use it internally for layers, so that “everything is accounted for” as Pytrees).
But I may have misunderstood the issue!
Ah, I see!
I'd prefer not to create wrappers like this, to be honest. It's important that Equinox be directly compatible with JAX -- we don't really want to sit as a layer on top of it. The central thesis here is that Equinox is not a JAX framework; it should only be a JAX library.
Hi,
not sure if this is really a bug or intended, but it's not possible to pass an eqx.Module as carry in a jax.lax.scan. Here is the MVP:
This leads to this error
On the other hand, I could just use this (MVP 2):
In MVP 2, I'm using the mlp from the outer function inside the scan. While this works, there could be scenarios in which I update the PyTree inside the scan and since no shapes are changed, it made me think that it would be allowed. I don't necessarily want to change the "outer" mlp from inside the scan function as it's not directly a part of the function (I don't want to change some global states!).
But as already mentioned, I'm not sure if this really is a bug or not.