patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Bug in how eqx.Module interacts with jax.tree.map #889

Open xaviergonzalez opened 2 weeks ago

xaviergonzalez commented 2 weeks ago

There seems to be a bug in how eqx.Module interacts with jax.tree.map

Here is the repro:

import jax
import equinox as eqx
import jax.numpy as jnp
from flax import linen as nn

class Identity:

    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityEqx(eqx.Module):

    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityFlax(nn.Module):

    def __init__(self):
        pass

    def __call__(self, x):
        return x

# everything is fine without equinox
tst_fxns = [Identity(), Identity()]
tst_states = [jnp.zeros(2), jnp.zeros(2)]
print(jax.tree.map(lambda f, x: f(x), tst_fxns, tst_states))

# everything is fine in flax (inheritance is not a problem)
tst_fxns_flax = [IdentityFlax(), IdentityFlax()]
print(jax.tree.map(lambda f, x: f(x), tst_fxns_flax, tst_states))

# with equinox, this raises:
# Custom node type mismatch: expected type: <class '__main__.IdentityEqx'>, value: Array([0., 0.], dtype=float32).
tst_fxns_eqx = [IdentityEqx(), IdentityEqx()]
print(jax.tree.map(lambda f, x: f(x), tst_fxns_eqx, tst_states))

We would really like to be able to apply different functions across a list or array of inputs. We'd love to do so in equinox especially. The behavior seems like an equinox bug because we don't have this problem when we inherit from flax. Do you have any suggestions for a workaround? Do you know what the nature of this bug is and how it could be fixed?

patrick-kidger commented 2 weeks ago

Swap your final line for:

print(jax.tree.map(lambda x, f : f(x), tst_states, tst_fxns_eqx))

This isn't a bug. Equinox modules are pytrees. jax.tree.map requires that the first tree (in your case tst_fxns_eqx, with structure [IdentityEqx(), IdentityEqx()]) be a prefix of all later trees (in this case tst_states, with structure [*, *]). In particular note that your IdentityEqx class is an 'empty' pytree -- e.g. like an empty list [].
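
To make the "empty pytree" point concrete, here is a minimal sketch that uses plain JAX only. The `EmptyNode` class is a hypothetical stand-in for a field-less `eqx.Module` (it is not Equinox code); it is registered as a pytree node with no children, which is the assumption being illustrated.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class EmptyNode:
    """Hypothetical stand-in for a field-less eqx.Module: a pytree node with no children."""

    def tree_flatten(self):
        # No children (leaves) and no auxiliary data.
        return (), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


fns = [EmptyNode(), EmptyNode()]
states = [jnp.zeros(2), jnp.zeros(2)]

# The list of empty nodes contains no leaves at all, while the list of
# arrays contains one leaf per array.
fn_leaves = jax.tree.leaves(fns)
state_leaves = jax.tree.leaves(states)
print(len(fn_leaves), len(state_leaves))

# The two tree structures differ, and the first is not a prefix of the second.
print(jax.tree.structure(fns))
print(jax.tree.structure(states))
```

Because `fns` has an `EmptyNode` where `states` has an array leaf, JAX cannot line the two trees up when `fns` is passed first.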

If this seems confusing, here's a simpler non-Equinox equivalent you can study, that will also raise an error:

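import jax

fn = lambda a, b: None
tree1 = [[], []]            # two empty subtrees, so no leaves at all
tree2 = ['leaf1', 'leaf2']  # two leaves
jax.tree.map(fn, tree1, tree2)  # raises: tree1 is not a prefix of tree2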
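
For completeness, here is a hedged sketch of two ways to satisfy the prefix rule. It again uses a hypothetical stand-in class, `IdentityNode`, registered by hand as a childless pytree node, rather than a real `eqx.Module`. The first option is the argument swap suggested above. The second keeps the original argument order and uses the `is_leaf` argument of `jax.tree.map` so that each callable is treated as a leaf; that second option is an addition for illustration, not something proposed in this thread.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class IdentityNode:
    """Hypothetical stand-in for IdentityEqx: a callable pytree node with no children."""

    def __call__(self, x):
        return x

    def tree_flatten(self):
        return (), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


fns = [IdentityNode(), IdentityNode()]
states = [jnp.ones(2), jnp.ones(2)]

# Option 1: pass the array tree first. Its structure [*, *] is a prefix of
# the function tree, so each function arrives in the lambda as a whole subtree.
out_swapped = jax.tree.map(lambda x, f: f(x), states, fns)

# Option 2: keep the original order, but tell JAX to stop at each callable
# so the first tree also has structure [*, *].
out_is_leaf = jax.tree.map(
    lambda f, x: f(x),
    fns,
    states,
    is_leaf=lambda node: isinstance(node, IdentityNode),
)

print(out_swapped)
print(out_is_leaf)
```

Both calls return a list of two arrays equal to the inputs, since each function here is the identity.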

xaviergonzalez commented 2 weeks ago

Thank you so much for your very helpful and fast reply!