So `EOFError` is often caused by passing in a `like=...` argument that doesn't match what you're actually deserialising. For example, if you pass in a `like=...` with more leaves than the serialised pytree, then it will attempt to deserialise more objects than are actually saved in the file.
As such I suspect it probably is something to do with your `opt_state` changing. Whether that's due to `EmptyState()` or something else is hard for me to pin down from what you've described.
You can probably debug this just by carefully checking what you're serialising against what you're passing in as `like=...`, and looking for any discrepancy.
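For illustration, a minimal sketch of that failure mode (hypothetical filename and arrays, not taken from the original report):
```python
import equinox as eqx
import jax.numpy as jnp

# Serialise a pytree with two leaves.
saved = (jnp.zeros(3), jnp.ones(3))
eqx.tree_serialise_leaves("state.eqx", saved)

# Deserialising with a `like` of the same structure works fine.
restored = eqx.tree_deserialise_leaves("state.eqx", saved)

# Deserialising with a `like` that has *more* leaves tries to read past
# the end of the file, which is the EOFError described above.
too_many_leaves = (jnp.zeros(3), jnp.ones(3), jnp.ones(3))
eqx.tree_deserialise_leaves("state.eqx", too_many_leaves)  # EOFError
```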
Thanks! So uhh, wouldn't it become a chicken-and-egg problem if I have to store/serialise the leaves of `opt_state` itself to later use that for de-serialisation?
I ran a few training steps with the debugger, but serialisation/de-serialisation still works. Additionally, if one tries to repro with something like this:
```python
import jax
import optax
import equinox as eqx
import jax.numpy as jnp
from copy import deepcopy
from jaxtyping import Array

key = jax.random.PRNGKey(5)
model = eqx.nn.Linear(4, 1, key=key)
input_arr = jnp.ones((4, 4))
output_arr = jnp.zeros(4).astype(int)

# AdamW optimizer with weight decay
optim = optax.chain(
    optax.clip(1.0),
    optax.adamw(1e-2, weight_decay=1e-4)
)

opt_state_init = optim.init(eqx.filter(model, eqx.is_array))
eqx.tree_serialise_leaves('test.eqx', (model, opt_state_init))

@eqx.filter_value_and_grad
def compute_loss(model: eqx.Module, input_arr: Array, output_arr: Array):
    preds = model(input_arr)
    loss = jnp.mean((preds - output_arr) ** 2)
    return loss

@eqx.filter_jit
def make_step(model: eqx.Module, input_arr: Array, output_arr: Array,
              optim, opt_state):
    loss, grads = compute_loss(model, input_arr, output_arr)
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

opt_state_intermediate = deepcopy(opt_state_init)

for epoch in range(2):
    for steps in range(20):
        loss, model, opt_state_intermediate = make_step(
            model, input_arr, output_arr, optim, opt_state_intermediate
        )
        print(loss, model(input_arr))

model, opt_state = eqx.tree_deserialise_leaves('test.eqx', (model, opt_state_intermediate))
```
It still works.
I guess I can just keep logging `opt_state` and see if it ever changes during training?
Anyways, this isn't really an issue with Equinox, so I'll close it. Thanks for giving me a direction for debugging!
> Thanks! So uhh, wouldn't it become a chicken-and-egg problem if I have to store/serialise the leaves of `opt_state` itself to later use that for de-serialisation?
So the usual pattern for de/serialisation is something like:
```python
def run(...hyperparams..., load_from=None):
    model, opt_state = setup(...hyperparams...)
    if load_from is not None:
        model, opt_state = eqx.tree_deserialise_leaves(load_from, (model, opt_state))
    ...  # do training
    eqx.tree_serialise_leaves(save_path, (model, opt_state))
```
where `setup` will do whatever's necessary to get pytrees of shapes and structures that don't change.
For what it's worth, thinking about it, I'd be a bit surprised if Optax were changing the structure of the optimiser state during training. I've not bumped into that before myself, and moreover that would necessarily imply recompilation of the `make_step` function, which would be a bit of a terrible outcome all on its own.
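A quick sketch of why that follows: a jitted function is retraced (and recompiled) whenever the pytree structure of its arguments changes, which is easy to observe with a Python-side print that only fires during tracing. (The function and state names here are hypothetical.)
```python
import jax
import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit
def f(state):
    print("tracing...")  # only runs when the function is (re)traced/compiled
    return jax.tree_util.tree_map(lambda x: x + 1, state)

state_a = {"count": jnp.zeros(())}
f(state_a)  # prints "tracing..."
f(state_a)  # no print: same pytree structure, cached compilation reused

state_b = {"count": jnp.zeros(()), "extra": jnp.zeros(())}
f(state_b)  # prints "tracing..." again: a different pytree structure forces recompilation
```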
I've made sure to use the exact same hyperparameters. It's definitely odd.
Well, the codebase is definitely efficient and ~5x faster than the torch one (plus >98% GPU util), so I doubt recompilation is taking place.
Which means the PyTree in theory should definitely not change.
This is a mystery.
https://www.diffchecker.com/dHu0xILh/
This is the diff between the `opt_state` after a handful of epochs, and the initial `opt_state` used as a reference by `like` for de-serialisation. There isn't much of a difference, as you can see.
An interesting observation: if I do
```python
load_eqx_obj(f'{self.save_dir}model_{epoch}.eqx', (model,))
```
instead of
```python
load_eqx_obj(f'{self.save_dir}model_{epoch}.eqx', (model, opt_state))
```
then the loading works. Which would imply that in the initial serialisation, the PyTree `(model, opt_state)` wasn't being serialised properly - specifically the `opt_state`.
What happens if you do `optim.init(eqx.filter(model, eqx.is_inexact_array))`?
Huh. That works. Why would that be? `Array`s are `Array`s no matter their format - no?
Note the distinction between `optim.init(eqx.filter(model, eqx.is_array))` and `optim.init(eqx.filter(model, eqx.is_inexact_array))`.
You only want Optax to optimise the floating-point arrays, typically. Looking at your diff, it seems that you have a PRNG key (which is an integer array) saved in your pytree, and it seems like Optax is doing something funny when being asked to optimise that.
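For illustration, a minimal sketch (with a hypothetical `MyModule`) of how the two filters treat a PRNG key stored on the model:
```python
import jax
import equinox as eqx

class MyModule(eqx.Module):
    linear: eqx.nn.Linear
    key: jax.Array  # an integer PRNG key stored as a model attribute

    def __init__(self, key):
        lkey, self.key = jax.random.split(key)
        self.linear = eqx.nn.Linear(4, 1, key=lkey)

module = MyModule(jax.random.PRNGKey(0))

# `is_array` keeps every array leaf, including the integer PRNG key,
# so Optax would be asked to "optimise" the key as well.
with_key = eqx.filter(module, eqx.is_array)

# `is_inexact_array` keeps only floating-point leaves, so the key is
# filtered out (replaced by None) and never reaches the optimiser.
without_key = eqx.filter(module, eqx.is_inexact_array)

print(jax.tree_util.tree_leaves(with_key))     # weight, bias, key
print(jax.tree_util.tree_leaves(without_key))  # weight, bias only
```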
Impressive you managed to catch that - I'd be tearing my hair out after that.
Why would `key` be considered for optimization? Any official way we can just block it for the optimizer?
My model has some random component, but that's all external to the model... hmm..
> Why would `key` be considered for optimization?
When you do `optim.init(foo)`, you're telling Optax that every leaf in the pytree `foo` should be optimised. In this case you have this key as part of your pytree.
> Any official way we can just block it for the optimizer?
It's fairly common with most models to (a) have all floating-point arrays be parameters, and (b) everything else be a non-parameter. Thus, `optim.init(eqx.filter(foo, eqx.is_inexact_array))` is usually the appropriate way to block this from the optimiser.
This all aside, it's very unusual to need to save a key as part of the model. Keys are normally passed at `__init__` time to initialise parameters, or at `__call__` time to provide the randomness for stochastic layers (e.g. dropout). Neither of those use-cases requires saving the key on the model.
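As a minimal sketch of that pattern (hypothetical `MLP` module, not from the original thread): the key used for parameter initialisation is consumed in `__init__`, and a fresh key is passed at `__call__` time for dropout, so no key ever needs to live on the model.
```python
import jax
import jax.numpy as jnp
import equinox as eqx

class MLP(eqx.Module):
    linear: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, key):
        # Key used once, at __init__ time, to initialise parameters.
        self.linear = eqx.nn.Linear(4, 4, key=key)
        self.dropout = eqx.nn.Dropout(p=0.5)

    def __call__(self, x, *, key):
        # Key passed at __call__ time supplies the dropout randomness.
        return self.dropout(self.linear(x), key=key)

init_key, call_key = jax.random.split(jax.random.PRNGKey(0))
model = MLP(init_key)
out = model(jnp.ones(4), key=call_key)
```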
> When you do `optim.init(foo)`, you're telling Optax that every leaf in the pytree `foo` should be optimised. In this case you have this key as part of your pytree.
Oh, I know that! I meant: in what case would a user need to "optimize" the key - and from jax/equinox's perspective, what would that actually do for a `PRNGKey` under the hood? Fiddle with the seed?
> This all aside, it's very unusual to need to save a key as part of the model. Keys are normally passed at `__init__` time to initialise parameters, or at `__call__` time to provide the randomness for stochastic layers (e.g. dropout). Neither of those use-cases requires saving the key on the model.
I construct an attribute `self.key`. Perhaps that might be the issue?
> in what case would a user need to "optimize" the key
I don't think it's a well defined notion. :D
> I construct an attribute `self.key`. Perhaps that might be the issue?
Why do you have this attribute at all? It doesn't look like it's being used anywhere.
`eqx.nn.Dropout` requires a `key` to be passed for `__call__`, which means lots of ugly changes all down each module's `__call__`, plus ignoring axes with `vmap`, etc.
So setting it as an attribute re-uses the `key`, and thus doesn't require it being passed for `__call__`, as we just access the attribute.
I take it there's probably a neater way to do that?
That's probably not what you want. It means you'll get the exact same dropout mask every time. In JAX, all randomness is a deterministic function of the PRNG key you provide.
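A quick sketch of what that determinism means in practice: the same key always produces the same dropout mask, so a key stored once on the model gives identical masks forever.
```python
import jax
import jax.numpy as jnp
import equinox as eqx

dropout = eqx.nn.Dropout(p=0.5)
x = jnp.ones(8)
key = jax.random.PRNGKey(0)

# Same key -> exactly the same mask, every single call.
assert (dropout(x, key=key) == dropout(x, key=key)).all()

# A different key gives a different mask.
key1, key2 = jax.random.split(key)
print(dropout(x, key=key1))
print(dropout(x, key=key2))
```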
Cool, then I'll just pre-generate an array of `PRNGKey`s every epoch and then `vmap` across that axis.
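As a sketch of that idea (hypothetical shapes; a standalone `eqx.nn.Dropout` stands in for the real model): split one key into a batch of keys each epoch, then `vmap` over the inputs and keys together.
```python
import jax
import jax.numpy as jnp
import equinox as eqx

dropout = eqx.nn.Dropout(p=0.5)

batch_size = 32
epoch_key = jax.random.PRNGKey(42)

# One key per example in the batch, regenerated every epoch (or step).
keys = jax.random.split(epoch_key, batch_size)
xs = jnp.ones((batch_size, 4))

# vmap over both the batch axis of the inputs and the key axis, so each
# example gets its own dropout mask.
outs = jax.vmap(lambda x, k: dropout(x, key=k))(xs, keys)
```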
Thanks for all your help! Have a great day!
Hi, it's me again!
This is rather a bizarre problem that I haven't been able to triangulate, let alone reproduce.
Effectively, I'm serialising both my `model` and optimiser state in a single PyTree: `eqx.tree_serialise_leaves(..., (model, opt_state))`.
But later on, when I want to resume my model's training, I do the standard `eqx.tree_deserialise_leaves(..., (model, opt_state))`, which leads to an `EOFError`.
I'm using `AdamW` offered by `optax`. I noticed some interesting behavior:
- On Colab, this message is replaced by one which complains that the underlying `jax.load` has `allow_pickle` set to `False`. But of course, manually overriding that in the package still doesn't do much, as the serialised object is not a `pickle` file.
- The `opt_state` is a `tuple` which contains an `EmptyState()`, and the other index holds the `ScaleByAdamState` + schedule state pytree (as well as an extra `EmptyState()`). I mention this because, as we call the optimiser for training, those placeholder `EmptyState()`s hold some other PyTrees as we go on, and eventually if they're serialised, they can't be de-serialised because `like` has an incorrect reference and thus leads to the error. Is that plausible?
- If I replace the `init`-ed `opt_state` with a mutated tuple - one simply being `(EmptyState(), EmptyState())` - then the de-serialisation gives no errors. But the newly loaded `opt_state` just has `EmptyState()`s in it. Does that mean `jax` just silently ignored the de-serialisation as it's, well, already empty, and loaded nothing? Or does that imply that during the serialisation process, the PyTrees were indeed just a couple of `EmptyState()`s?

This is a bit weird and I can't seem to wrap my head around how I'm supposed to fix this. Would you have any idea how I can proceed with debugging this?
Thanks so much again. Sorry for all the questions and trouble!