patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

InitVar and deepcopy break PyTreeDef equality #857

Status: Open. Opened by HGangloff 1 week ago.

HGangloff commented 1 week ago

Hi,

In an optimization process, I want to compare new parameter values to old ones that I stored using a deepcopy. Jitting the optimization function then fails because the tree structure of my parameters has been modified. Below is an MWE in which the two tree structures no longer compare equal; this is the root of the trouble in my full program.

from dataclasses import InitVar
from copy import deepcopy

import jax
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(deepcopy(params))[1]) # returns False!
print(jax.tree.flatten(params)[1], jax.tree.flatten(deepcopy(params))[1])
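For context on why a treedef mismatch is fatal: a custom pytree node's aux data is stored in its treedef, so two trees with identical leaves can still have unequal structures, and anything that combines them leaf by leaf (such as `jax.tree.map`) will reject them, whether or not it runs under jit. The sketch below reproduces that with a hypothetical `Tagged` class (not part of Equinox) and an equally hypothetical `check_same_structure` helper that fails early, outside jit, and shows both treedefs. I'm assuming something analogous differs between the deepcopied Module and the original, but I haven't confirmed what.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Tagged:
    """Hypothetical pytree node: one array child plus static aux data."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        # `value` is a leaf; `tag` is aux data and becomes part of the treedef.
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(children[0], tag)


def check_same_structure(a, b):
    """Hypothetical helper: raise outside jit, with both treedefs in the message."""
    ta, tb = jax.tree.structure(a), jax.tree.structure(b)
    if ta != tb:
        raise ValueError(f"pytree structures differ:\n  {ta}\n  {tb}")


x = jnp.ones(3)
a = Tagged(x, tag="a")
b = Tagged(x, tag="b")  # identical leaf, different aux data

print(jax.tree.structure(a) == jax.tree.structure(b))  # False: aux data differs

try:
    # Combining the two trees leaf by leaf fails on the aux-data mismatch.
    jax.tree.map(lambda u, v: u - v, a, b)
except ValueError as err:
    print("jax.tree.map failed:", type(err).__name__)
```

Calling `check_same_structure(params, old_params)` right before the jitted step would surface the problem with a readable message instead of a trace-time error deep inside the optimizer.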

Note that I found the bug disappears when I don't use InitVar (although this is probably less elegant):

class MLP1(eqx.Module):

    layers: list = eqx.field(init=False)

    def __init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP1(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(deepcopy(params))[1]) # returns True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(deepcopy(params))[1])

Is the problem really due to InitVar? Should I use something other than deepcopy?

Thanks!
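One possible alternative to deepcopy is to flatten the tree once, copy the array leaves, and rebuild the copy from the original treedef with `jax.tree.unflatten`, so the copy's structure cannot drift from the original. This is a sketch on a plain nested dict; `copy_with_same_treedef` is a hypothetical helper name, and I haven't checked it against the MLP2 example above.

```python
import jax
import jax.numpy as jnp


def copy_with_same_treedef(tree):
    """Hypothetical helper: copy array leaves but reuse the original treedef."""
    leaves, treedef = jax.tree.flatten(tree)
    # Copy arrays; pass any non-array leaves through unchanged.
    copied = [jnp.copy(x) if isinstance(x, jax.Array) else x for x in leaves]
    # Rebuilding from the *original* treedef keeps the structure identical.
    return jax.tree.unflatten(treedef, copied)


params = {"layers": [{"weight": jnp.ones((50, 1)), "bias": jnp.zeros(50)}, None]}
old_params = copy_with_same_treedef(params)

print(jax.tree.structure(old_params) == jax.tree.structure(params))  # True
```

The design choice here is that only the leaves are ever copied; the treedef object itself is reused, so structural equality does not depend on how any node type behaves under `copy.deepcopy`.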

SimonKoop commented 1 week ago

Why do you need a copy of the parameters? Jittable functions should be free of side effects, and JAX arrays are immutable, so you can simply keep a reference to the original array and compare the new array against it instead of against a deep copy.
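To illustrate: functional updates on JAX arrays return new arrays, so holding on to the old pytree is already a safe snapshot. A minimal sketch on a plain dict of parameters, where the gradients and the 0.1 learning rate are made up for the example:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones(3), "b": jnp.zeros(1)}
old_params = params  # just keep a reference; no copy is needed

# Hypothetical update step: jax.tree.map builds a *new* pytree.
grads = {"w": jnp.full(3, 0.5), "b": jnp.ones(1)}
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

print(bool(jnp.all(old_params["w"] == 1.0)))  # True: the original is untouched
print(jax.tree.structure(params) == jax.tree.structure(old_params))  # True
```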

If you really do need to copy arrays, I guess you could use jax.numpy.copy (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.copy.html). Then you could do something like:

from dataclasses import InitVar
import jax
from jax import numpy as jnp
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

params_copy = jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else x, params)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_copy)[1]) # returns True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(params_copy)[1])
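Once the old and new parameters share a treedef, the comparison itself can run inside jit. A sketch, where `max_param_change` is a hypothetical helper and the parameters are a plain dict standing in for the partitioned Module:

```python
import jax
import jax.numpy as jnp


@jax.jit
def max_param_change(new_params, old_params):
    """Largest absolute elementwise change across all array leaves.

    Both pytrees must share a treedef; jax.tree.map raises otherwise.
    """
    per_leaf = jax.tree.map(lambda n, o: jnp.max(jnp.abs(n - o)), new_params, old_params)
    return jnp.max(jnp.stack(jax.tree.leaves(per_leaf)))


params = {"weight": jnp.ones((50, 1)), "bias": jnp.zeros(50)}
# Same copying pattern as above: copy arrays, leave anything else alone.
params_copy = jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else x, params)
new_params = jax.tree.map(lambda x: x + 0.25, params)

print(float(max_param_change(params, params_copy)))      # 0.0
print(float(max_param_change(new_params, params_copy)))  # 0.25
```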

patrick-kidger commented 1 week ago

Hmm. This is really weird! I've poked at this a little bit and you're right, it's specifically the interaction of InitVar[...] and deepcopy. I have no idea why this should be the case.