This is similar to #857: there seems to be a bad interaction between `InitVar` and `pickle`. The documentation describes better practices than resorting to `pickle`, and following them I have been able to solve the issue with a better workaround, but I think it is worth opening the issue for the record.
Pickling the parameters of a model that has an `InitVar` field breaks `PyTreeDef` equality:
```python
import pickle
from dataclasses import InitVar

import jax
import equinox as eqx
from jaxtyping import Key


class MLP2(eqx.Module):
    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t


key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

# Compare the tree structures (PyTreeDefs) before and after the round trip.
print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints False!
```
The same code works once the `InitVar` is removed:
```python
# Same imports as in the previous snippet.


class MLP1(eqx.Module):
    layers: list = eqx.field(init=False)

    def __init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t


key = jax.random.PRNGKey(0)
mlp = MLP1(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints True
```