Open · HGangloff opened this issue 2 months ago
Why do you need a copy of the parameters? All jittable functions should be free of side effects, and JAX arrays are immutable, so you can just store the original array and compare the new array to it, instead of to some deep copy.
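For instance, a minimal sketch of that store-and-compare pattern (the pytree here is made up purely for illustration):

import jax
from jax import numpy as jnp

# JAX arrays are immutable, so holding a reference to the old pytree is safe:
# any functional update produces brand-new arrays and never mutates the originals.
old_params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

# A purely functional update step returns a fresh pytree.
new_params = jax.tree.map(lambda p: p + 0.1, old_params)

# Compare the new values against the stored originals, leaf by leaf.
max_change = jax.tree.map(lambda old, new: jnp.max(jnp.abs(new - old)), old_params, new_params)
print(max_change)  # every leaf changed by 0.1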
If you really do need to copy arrays, you could use https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.copy.html, I guess? Then you could do something like:
from dataclasses import InitVar
import jax
from jax import numpy as jnp
import equinox as eqx
from jaxtyping import Key
class MLP2(eqx.Module):
    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t
key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
params_copy = jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else x, params)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_copy)[1])  # returns True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(params_copy)[1])
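Note that with jnp.copy the flattened treedefs compare equal, so a jitted function sees the same pytree structure for the copy as for the original parameters.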
Hmm. This is really weird! I've poked at this a little bit and you're right, it's specifically the interaction of InitVar[...] and deepcopy. I have no idea why this should be the case.
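A minimal sketch of that interaction, reusing the MLP2 class from the snippet above (the treedef mismatch shown here is the behavior reported in this issue, not something guaranteed by the API):

import copy
import jax
import equinox as eqx

# MLP2 as defined above, i.e. with an InitVar key and __post_init__.
mlp = MLP2(key=jax.random.PRNGKey(0))
params, static = eqx.partition(mlp, eqx.is_inexact_array)

# Deep-copying the partitioned parameters changes the treedef:
params_deep = copy.deepcopy(params)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_deep)[1])  # reportedly False

# A jitted optimization step then sees a different pytree structure for the
# deep copy, which matches the error described in this issue.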
Hi,

In some optimization process, I want to compare new parameter values to old ones that I stored using a deepcopy. I get an error when jitting the optimization function because the tree structure of my parameters has been modified. See below a MWE, where we lose the tree structure equality, which is the root of the trouble in my complete program.

Note that I found out the bug disappears when not using InitVar (probably less elegant, though). Is the problem really due to InitVar? Should I use something else rather than deepcopy?

Thanks!
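For illustration, a hypothetical sketch of such an InitVar-free variant (the class name and layer sizes are made up, not the author's original code); with this version the treedefs reportedly stay equal under deepcopy:

import copy
import jax
import equinox as eqx

class MLP3(eqx.Module):
    # A plain __init__ instead of InitVar + __post_init__.
    layers: list

    def __init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

mlp = MLP3(key=jax.random.PRNGKey(0))
params, static = eqx.partition(mlp, eqx.is_inexact_array)
params_deep = copy.deepcopy(params)

# Without InitVar the tree structure reportedly survives the deep copy:
print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_deep)[1])  # True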