patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

eqx.Modules cannot be passed as carry into jax.lax.scan #630

Open Artur-Galstyan opened 10 months ago

Artur-Galstyan commented 10 months ago

Hi,

Not sure whether this is a bug or intended behaviour, but it's not possible to pass an eqx.Module as the carry of a jax.lax.scan. Here is the MVP (minimal example):

import jax
import jax.numpy as jnp
import equinox as eqx

class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)

key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)

def rollout(mlp, xs):
    def step(carry, x):
        (mlp,) = carry  # the carry is a one-element list holding the Module
        val = mlp(x)
        carry = [mlp]  # return the carry with the same structure it came in with
        return carry, [val]

    _, scan_out = jax.lax.scan(
        step,
        [mlp],
        xs
    )

    return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

This leads to this error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[79], line 32
     29     return scan_out
     31 key, subkey = jax.random.split(key)
---> 32 vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

Cell In[79], line 23, in rollout(mlp, xs)
     20     carry = mlp
     21     return carry, [val]
---> 23 _, scan_out = jax.lax.scan(
     24     step,
     25     [mlp],
     26     xs
     27 )
     29 return scan_out

    [... skipping hidden 5 frame]

File ~/Workspace/jaxRL/.venv/lib/python3.11/site-packages/jax/_src/core.py:1423, in concrete_aval(x)
   1421 if hasattr(x, '__jax_array__'):
   1422   return concrete_aval(x.__jax_array__())
-> 1423 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1424                  "type")

TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x10f5846d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type

On the other hand, I could just use this (MVP 2):

import jax
import jax.numpy as jnp
import equinox as eqx

class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)

key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)

def rollout(mlp, xs):
    def step(carry, x):
        val = mlp(x)
        return carry, [val]

    _, scan_out = jax.lax.scan(
        step,
        [],
        xs
    )

    return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

In MVP 2, I'm using the mlp from the outer function inside the scan. While this works, there are scenarios in which I want to update the PyTree inside the scan, and since no shapes change, I assumed this would be allowed. I also don't want to modify the "outer" mlp from inside the scan function, since it isn't an argument of that function (I don't want to mutate global state!).

But as already mentioned, I'm not sure if this really is a bug or not.

lockwo commented 10 months ago

I've actually encountered something very similar before, where I needed the first MVP (the second wouldn't work). My solution was to write a sort of filter_scan: partition the Module into its array and non-array parts, and pass only the arrays through the carry. Here is what that looks like applied to your code:

import jax
import jax.numpy as jnp
import equinox as eqx

class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)

key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)

def rollout(mlp, xs):
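    # Split the Module: floating-point arrays go into the carry, everything else is closed over.
    arr, static = eqx.partition(mlp, eqx.is_inexact_array)
    def step(carry, x):
        mlp = eqx.combine(carry, static)  # rebuild the full Module from the carried arrays
        val = mlp(x)
        carry, _ = eqx.partition(mlp, eqx.is_inexact_array)  # carry only the arrays forward
        return carry, [val]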

    _, scan_out = jax.lax.scan(
        step,
        arr,
        xs
    )

    return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
print(vals)

and this does execute.

patrick-kidger commented 10 months ago

Right. See #446, which was another example of this issue.

Note that this is completely unrelated to the use of eqx.Module, though. (Modules are just PyTrees like any other!) The issue here is that your PyTree has a non-array in it (the activation function of your MLP), and jax.lax.scan requires that its carry be specifically a PyTree-of-arrays, not anything else.
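To illustrate that eqx.Module plays no role here, a small plain-JAX sketch: a tuple carry whose second leaf is jax.nn.relu (the same custom_jvp object that appears in the traceback) is rejected by jax.lax.scan, while an arrays-only carry that closes over the function works. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    h, act = carry
    # Apply the activation and pass the carry through unchanged.
    return (h, act), act(h + x)


xs = jnp.ones((5, 3))

# A plain tuple carry with a function leaf: scan rejects it, no Module involved.
bad_init = (jnp.zeros(3), jax.nn.relu)
try:
    jax.lax.scan(step, bad_init, xs)
    scan_failed = False
except TypeError as e:
    scan_failed = True
    print("TypeError:", e)


# Keeping only arrays in the carry (and closing over the function) works.
def step_arrays_only(h, x):
    return h, jax.nn.relu(h + x)


_, ys = jax.lax.scan(step_arrays_only, jnp.zeros(3), xs)
print(ys.shape)  # (5, 3)
```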

Artur-Galstyan commented 10 months ago

Ah I see, that makes perfect sense. Thanks!

femtomc commented 10 months ago

@patrick-kidger Correct me if I'm wrong, but shouldn't the non-array be marked as static (with respect to the Pytree structure), and work correctly with jax.lax.scan automatically?

patrick-kidger commented 10 months ago

If the non-array happens to be marked as static, then yes! But there's no requirement that all non-arrays be static fields, and it can happen that non-arrays are in the leaves of the PyTree as well.

In this case, the activation function of eqx.nn.MLP is not static, so this issue arises. (Why is it not static? Because one might have learnt activation functions, e.g. eqx.nn.PReLU. We can only mark a field as static if we're certain it will never have any arrays in.)

femtomc commented 10 months ago

But if it is learnt, it holds parameters -- and if it holds parameters, it should be an eqx.Module, no?

My point is that if one is careful and uses eqx.Module recursively throughout a code base, this bug should never occur.

femtomc commented 10 months ago

I think the problem here is that the activation function is not an eqx.Module. If we're willing to be extreme (and I think it's worth it to avoid situations like this one), all activation functions should be registered as PyTrees, even if each is just an eqx.Module with a single static field holding the function ("just a function").

patrick-kidger commented 10 months ago

> But if it is learnt, it holds parameters -- and if it holds parameters, it should be an eqx.Module, no?

Nope. Modules are just PyTrees like tuples/lists/etc, they have no special behaviour beyond that. (The fact they are also custom classes is what means we can put methods on them, i.e. for a forward pass. That's unrelated.)

femtomc commented 10 months ago

Sure, I know how PyTrees work, and how abstract base classes that register their subclasses as PyTrees work. My point was just that if something has learned parameters, it should probably be implemented as an eqx.Module.

That doesn't answer my question, though: why not make every activation callable exported by eqx a Module by default?

Yes, someone could still run afoul of this issue by using their own callable in an architecture they build, but it would remove the issue for all Equinox activation functions, no?

patrick-kidger commented 10 months ago

So the only activation function provided by Equinox is eqx.nn.PReLU, and this is indeed a Module. I think we already meet the requirement you're asking for, or do I misunderstand you?

femtomc commented 10 months ago

No that’s right!

But presumably layers can use other activations, e.g. from jax.nn. So I meant exposing wrappers that explicitly mark "the code" (the function) as static inside a Module, exporting that Module as a public symbol from eqx, and using it internally in layers, so that "everything is accounted for" as PyTrees.

But I may have misunderstood the issue!

patrick-kidger commented 10 months ago

Ah, I see!

I'd prefer not to create wrappers like this, to be honest. It's important that Equinox be directly compatible with JAX; we don't want to sit as a layer on top of it. The central thesis here is that Equinox is not a JAX framework, only a JAX library.