patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

InitVar and pickle break PyTreeDef equality #859

Open HGangloff opened 2 months ago

HGangloff commented 2 months ago

Hi,

This issue is similar to #857: there appears to be a bad interaction between InitVar and pickle. The documentation recommends better practices than pickle, and I have since found a better workaround that solves my problem, but I am opening this issue for the record.

Pickling model parameters containing an InitVar breaks PyTreeDef equality:

import pickle
from dataclasses import InitVar

import jax
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints False, but True is expected

The same pickle round trip preserves PyTreeDef equality once the InitVar is removed and the key is passed to a custom __init__ instead:

class MLP1(eqx.Module):

    layers: list = eqx.field(init=False)

    def __init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP1(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints True