xaviergonzalez opened this issue 2 weeks ago
Swap your final line for:
```python
print(jax.tree.map(lambda x, f: f(x), tst_states, tst_fxns_eqx))
```
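Below is a self-contained sketch of the swapped call that uses only JAX. `IdentityLike`, `tst_states` and `tst_fxns` are hypothetical stand-ins, because the issue's repro isn't shown. The class is registered as a pytree node with no children, on the assumption that this mirrors how a field-less `eqx.Module` such as `IdentityEqx` flattens.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for IdentityEqx: a pytree node with no children,
# assumed to flatten the same way a field-less eqx.Module does.
@jax.tree_util.register_pytree_node_class
class IdentityLike:
    def __call__(self, x):
        return x

    def tree_flatten(self):
        # No children (so no leaves) and no static aux data.
        return (), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


# Hypothetical inputs with structure [*, *].
tst_states = [jnp.array(1.0), jnp.array(2.0)]
tst_fxns = [IdentityLike(), IdentityLike()]

# States first: each array is a leaf, so the whole IdentityLike() subtree
# at the matching position in tst_fxns is passed in as `f`.
out = jax.tree.map(lambda x, f: f(x), tst_states, tst_fxns)
print(out)
```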
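This isn't a bug. Equinox modules are pytrees. `jax.tree.map` requires that the first tree (in your case `tst_fxns_eqx`, with structure `[IdentityEqx(), IdentityEqx()]`) be a prefix of every later tree (in this case `tst_states`, with structure `[*, *]`). In particular, note that your `IdentityEqx` class is an 'empty' pytree, much like an empty list `[]`. It is an internal node with no leaves. So wherever the first tree has an `IdentityEqx()` node, `tst_states` has an array leaf instead, and the structures don't match.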
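To see the mismatch concretely, the sketch below prints the leaves of each tree. It then shows that mapping with the function list first fails. `EmptyNode` is a hypothetical stand-in for `IdentityEqx`, assumed to flatten to no leaves like a field-less Equinox module. The inputs are likewise made up.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class EmptyNode:
    """Hypothetical stand-in for IdentityEqx: a pytree node with no leaves."""

    def __call__(self, x):
        return x

    def tree_flatten(self):
        return (), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


fns = [EmptyNode(), EmptyNode()]
states = [jnp.array(1.0), jnp.array(2.0)]

print(jax.tree.leaves(fns))          # the EmptyNode objects contribute no leaves
print(len(jax.tree.leaves(states)))  # two array leaves

# fns first: JAX expects an EmptyNode where `states` has an array, so it raises.
raised = False
try:
    jax.tree.map(lambda f, x: f(x), fns, states)
except (ValueError, TypeError) as e:
    raised = True
    print(type(e).__name__, e)
```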
If this seems confusing, here's a simpler, non-Equinox equivalent to study, which also raises an error:
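```python
import jax

fn = lambda a, b: None
tree1 = [[], []]            # two empty subtrees: zero leaves in total
tree2 = ['leaf1', 'leaf2']  # two leaves
jax.tree.map(fn, tree1, tree2)  # raises: tree1 is not a prefix of tree2
```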
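Putting the leaf-structured tree first works. Each string is a leaf, so the matching `[]` in `tree1` is passed to `fn` whole. Here is a minimal sketch using the same toy trees:

```python
import jax

fn = lambda a, b: None
tree1 = [[], []]
tree2 = ['leaf1', 'leaf2']

# tree2 first: its structure [*, *] is a prefix of tree1, so fn is called
# once per string, with b bound to the corresponding empty list.
result = jax.tree.map(fn, tree2, tree1)
print(result)
```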
Thank you so much for your very helpful and fast reply!
There seems to be a bug in how `eqx.Module` interacts with `jax.tree.map`.
Here is the repro:
We would really like to be able to apply different functions across a list or array of inputs, and we'd especially like to do so in Equinox. The behavior seems like an Equinox bug because we don't have this problem when we inherit from Flax. Do you have any suggestions for a workaround? Do you know what the nature of this bug is and how it could be fixed?
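Here are two possible workarounds for applying a list of functions to a list of inputs. Both are sketched with a hypothetical `EmptyNode` stand-in for `IdentityEqx`. Passing `is_leaf` makes `jax.tree.map` treat each function object as a leaf, so the function list can stay first. For real Equinox modules the predicate would presumably be `isinstance(n, eqx.Module)`, but that is an assumption not exercised here. For a flat list, a plain `zip` avoids pytree matching altogether.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class EmptyNode:
    """Hypothetical stand-in for IdentityEqx: a pytree node with no leaves."""

    def __call__(self, x):
        return x

    def tree_flatten(self):
        return (), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


fns = [EmptyNode(), EmptyNode()]
states = [jnp.array(1.0), jnp.array(2.0)]

# Option 1: is_leaf stops flattening at each function object, so the first
# tree has structure [*, *] and each array in `states` is matched to one f.
out1 = jax.tree.map(
    lambda f, x: f(x),
    fns,
    states,
    is_leaf=lambda n: isinstance(n, EmptyNode),
)

# Option 2: pair functions and inputs directly, with no pytree matching.
out2 = [f(x) for f, x in zip(fns, states)]
print(out1, out2)
```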