patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

[Question] De-serialising optimizer state yields EOF: No data left in file #498

Closed: neel04 closed this issue 1 year ago

neel04 commented 1 year ago

Hi, it's me again šŸ˜„

This is a rather bizarre problem that I haven't been able to triangulate, let alone reproduce.

Effectively, I'm serialising both my model and optimizer_state in a single PyTree: eqx.tree_serialise_leaves(..., (model, opt_state))

But later on, when I want to resume my model's training - I do the standard:

eqx.tree_deserialise_leaves(..., (model, opt_state))

Which leads to:

Traceback (most recent call last):
  File "/kaggle/working/ReAct_Jax/ReAct_Jax/train_model.py", line 62, in <module>
    main(key)
  File "/kaggle/working/ReAct_Jax/ReAct_Jax/train_model.py", line 58, in main
    trainer.train(args.epochs, trainloader, truncloader, valloader, testloader)
  File "/kaggle/working/ReAct_Jax/ReAct_Jax/ReAct/utils/trainer.py", line 142, in train
    model, opt_state, epoch_done = self.resume_training(model, opt_state)
  File "/kaggle/working/ReAct_Jax/ReAct_Jax/ReAct/utils/trainer.py", line 130, in resume_training
    model, opt_state = load_eqx_obj(f'{self.save_dir}model_{epoch}.eqx', (model, opt_state))
  File "/kaggle/working/ReAct_Jax/ReAct_Jax/ReAct/utils/helpers.py", line 17, in load_eqx_obj
    return eqx.tree_deserialise_leaves(path_or_file=filepath,
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 270, in tree_deserialise_leaves
    out = _ordered_tree_map(_deserialise, filter_spec, like)
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 25, in _ordered_tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 25, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 268, in _deserialise
    return _ordered_tree_map(__deserialise, x, is_leaf=is_leaf)
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 25, in _ordered_tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 25, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 266, in __deserialise
    return spec(f, y)
  File "/opt/conda/lib/python3.10/site-packages/equinox/_serialisation.py", line 103, in default_deserialise_filter_spec
    return jnp.load(f)
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 288, in load
    out = np.load(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/numpy/lib/npyio.py", line 436, in load
    raise EOFError("No data left in file")
EOFError: No data left in file

I'm using AdamW offered by optax. I noticed some interesting behavior:

  1. On Colab, this message is replaced by one which complains that the underlying np.load call has allow_pickle set to False. But of course, manually overriding that in the package still doesn't do much, as the serialised object is not a pickle file.

  2. The opt_state is a tuple: one index contains an EmptyState(), and the other index holds the ScaleByAdamState + schedule state pytree (as well as an extra EmptyState()).

I mention this because, as we call the optimiser during training, those placeholder EmptyState()s may come to hold other PyTrees, and if they eventually get serialised in that form, they can't be de-serialised, because like holds an incorrect reference and thus leads to the error. Is that plausible?

  3. If, when resuming training, I swap out the freshly init-ed opt_state for a mutated tuple, simply (EmptyState(), EmptyState()), then the de-serialisation gives no errors. But the newly loaded opt_state just has EmptyState()s in it.

Does that mean jax just silently skipped the de-serialisation, as it's, well, already Empty, and loaded nothing? Or does that imply that during the serialisation process, the PyTrees were indeed just a couple of EmptyState()s?
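
For reference, one way to inspect what the state actually contains at serialisation time (opt_state being my optimiser state from above) is just printing its treedef:

import jax

# The treedef shows every EmptyState() placeholder, and whatever
# may have replaced them, in the optimiser state's structure.
print(jax.tree_util.tree_structure(opt_state))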


This is a bit weird and I can't seem to wrap my head around how I'm supposed to fix this. Would you have any idea how I can proceed with debugging this?

Thanks so much again. Sorry for all the questions and trouble šŸ™ƒ

patrick-kidger commented 1 year ago

So an EOFError is often caused by passing in a like=... argument that doesn't match what you're actually deserialising. For example, if you pass in a like=... with more leaves than the serialised pytree, then it will attempt to deserialise more objects than are actually saved in the file.

As such I suspect it probably is something to do with your opt_state changing. Whether that's due to EmptyState() or something else is hard for me to pin down from what you've described.

You can probably debug this just by carefully checking what you're serialising against what you're passing in as like=..., and looking for any discrepancy.
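
For example, something like this (an untested sketch; saved is whatever tuple you passed to tree_serialise_leaves, and like is what you pass at load time) would surface a leaf-count or shape mismatch:

import jax

def compare_pytrees(saved, like):
    saved_leaves = jax.tree_util.tree_leaves(saved)
    like_leaves = jax.tree_util.tree_leaves(like)
    # A differing number of leaves is exactly the situation that
    # produces an EOFError at deserialisation time.
    print(f"saved: {len(saved_leaves)} leaves, like: {len(like_leaves)} leaves")
    for i, (a, b) in enumerate(zip(saved_leaves, like_leaves)):
        a_shape, b_shape = getattr(a, "shape", None), getattr(b, "shape", None)
        if a_shape != b_shape:
            print(f"leaf {i}: shape mismatch {a_shape} vs {b_shape}")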

neel04 commented 1 year ago

Thanks! So, uhh, wouldn't it become a chicken-and-egg problem if I have to store/serialise the leaves of opt_state itself to later use that for de-serialisation?

I ran a few training steps with the debugger, but serialisation/de-serialisation still works. Additionally, if one tries to repro it with something like this:

import jax
import optax
import equinox as eqx
import jax.numpy as jnp

from copy import deepcopy
from jaxtyping import Array

key = jax.random.PRNGKey(5)
model = eqx.nn.Linear(4, 1, key=key)

input_arr = jnp.ones((4, 4))
output_arr = jnp.zeros(4).astype(int)

# Gradient clipping chained with AdamW (which includes weight decay)
optim = optax.chain(
    optax.clip(1.0),
    optax.adamw(1e-2, weight_decay=1e-4)
)

opt_state_init = optim.init(eqx.filter(model, eqx.is_array))
eqx.tree_serialise_leaves('test.eqx', (model, opt_state_init))

@eqx.filter_value_and_grad
def compute_loss(model: eqx.Module, input_arr: Array, output_arr: Array):
    preds = model(input_arr)
    loss = jnp.mean((preds - output_arr) ** 2)

    return loss

@eqx.filter_jit
def make_step(model: eqx.Module, input_arr: Array, output_arr: Array,
              optim, opt_state):

    loss, grads = compute_loss(model, input_arr, output_arr)
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

opt_state_intermediate = deepcopy(opt_state_init)

for epoch in range(2):
    for steps in range(20):
        loss, model, opt_state_intermediate = make_step(model, input_arr, output_arr, optim, opt_state_intermediate)
        print(loss, model(input_arr))

model, opt_state = eqx.tree_deserialise_leaves('test.eqx', (model, opt_state_intermediate))

It still works.

I guess I can just keep logging opt_state and see if it ever changes during training?

Anyway, this isn't really an issue with equinox, so I'll close it. Thanks for giving me a direction for debugging!

patrick-kidger commented 1 year ago

Thanks! So, uhh, wouldn't it become a chicken-and-egg problem if I have to store/serialise the leaves of opt_state itself to later use that for de-serialisation?

So the usual pattern for de/serialisation is something like:

def run(...hyperparams..., load_from=None):
    model, opt_state = setup(...hyperparams...)
    if load_from is not None:
        model, opt_state = eqx.tree_deserialise_leaves(load_from, (model, opt_state))
    ... # do training
    eqx.tree_serialise_leaves(save_path, (model, opt_state))

where setup will do whatever's necessary to get pytrees of shapes and structures that don't change.
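
Fleshed out, that pattern might look something like the following (the Linear model and AdamW here are placeholders, not a prescription):

import equinox as eqx
import jax
import optax

def setup(key):
    # Always build the same structure, so that `like` at load time
    # matches what was serialised.
    model = eqx.nn.Linear(4, 1, key=key)
    optim = optax.adamw(1e-2)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    return model, optim, opt_state

def run(save_path, load_from=None):
    model, optim, opt_state = setup(jax.random.PRNGKey(0))
    if load_from is not None:
        model, opt_state = eqx.tree_deserialise_leaves(load_from, (model, opt_state))
    ...  # do training
    eqx.tree_serialise_leaves(save_path, (model, opt_state))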

For what it's worth, thinking about it, I'd be a bit surprised if Optax were changing the optimiser state's structure during training. I've not bumped into that myself, and moreover it would necessarily imply recompilation of the make_step function, which would be a bit of a terrible outcome all on its own.
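
If you want to test the recompilation hypothesis directly, JAX has a config option that logs every compilation:

import jax

# If `make_step` is logged as compiling more than once during training,
# then the structure of its inputs (e.g. opt_state) must have changed.
jax.config.update("jax_log_compiles", True)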

neel04 commented 1 year ago

I've made sure to use the exact same hyperparameters. It's definitely odd.

Well, the codebase is efficient and ~5x faster than the torch one (plus >98% GPU util), so I doubt recompilation is taking place.

Which means that, in theory, the PyTree should not be changing.

This is a mystery.

neel04 commented 1 year ago

https://www.diffchecker.com/dHu0xILh/

This is the diff between the opt_state after a handful of epochs, and the initial opt_state used as a reference by like for de-serialisation.

There isn't much of a difference as you can see.

An interesting observation: If I do:

load_eqx_obj(f'{self.save_dir}model_{epoch}.eqx', (model,))

instead of:

load_eqx_obj(f'{self.save_dir}model_{epoch}.eqx', (model, opt_state))

Then the loading works.

Which would imply that in the initial serialisation, the PyTree (model, opt_state) wasn't being serialised properly; specifically, the opt_state.

patrick-kidger commented 1 year ago

What happens if you do optim.init(eqx.filter(model, eqx.is_inexact_array))?

neel04 commented 1 year ago

Huh. That works. Why would that be? Arrays are arrays no matter their dtype - no?

patrick-kidger commented 1 year ago

Note the distinction between optim.init(eqx.filter(model, eqx.is_array)) and optim.init(eqx.filter(model, eqx.is_inexact_array)).

You only want Optax to optimise the floating-point arrays, typically. Looking at your diff, it seems that you have a PRNG key (which is an integer array) saved in your pytree, and it seems like Optax is doing something funny when being asked to optimise that.
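
To make that concrete, here's a toy module with a key stored on it (illustrative only; not your actual model):

import equinox as eqx
import jax
import optax

class ToyModel(eqx.Module):
    linear: eqx.nn.Linear
    key: jax.Array  # a PRNG key is a uint32 array, i.e. an integer array

    def __init__(self, key):
        self.linear = eqx.nn.Linear(4, 1, key=key)
        self.key = key

model = ToyModel(jax.random.PRNGKey(0))

# eqx.is_array keeps the key leaf, so Optax is asked to "optimise" it;
# eqx.is_inexact_array filters it out, leaving only floating-point leaves.
with_key = eqx.filter(model, eqx.is_array)
floats_only = eqx.filter(model, eqx.is_inexact_array)
opt_state = optax.adamw(1e-2).init(floats_only)  # no moments for the key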

neel04 commented 1 year ago

Impressive you managed to catch that - I'd be tearing my hair out after that.

Why would key be considered for optimization? Is there any official way we can just block it from the optimizer? šŸ¤”

My model has some random component, but that's all external to the model... hmm..

patrick-kidger commented 1 year ago

Why would key be considered for optimization?

When you do optim.init(foo), you're telling Optax that every leaf in the pytree foo should be optimised. In this case, you have this key as part of your pytree.

Any official way we can just block it for the optimizer?

It's fairly common with most models to (a) have all floating-point arrays be parameters, and (b) everything else be a non-parameter. Thus, optim.init(eqx.filter(foo, eqx.is_inexact_array)) is usually the appropriate way to block this from the optimiser.


This all aside, it's very unusual to need to save a key as part of the model. Keys are normally passed at __init__ time to initialise parameters, or at __call__ time to provide the randomness for stochastic layers (e.g. dropout). Neither of those use cases requires saving the key on the model.

neel04 commented 1 year ago

When you do optim.init(foo), you're telling Optax that every leaf in the pytree foo should be optimised. In this case, you have this key as part of your pytree.

Oh, I know that šŸ˜„ I meant: in what case would a user need to "optimize" the key? And from jax/equinox's perspective, what would that actually do to a PRNGKey under the hood? Fiddle with the seed?

This all aside, it's very unusual to need to save a key as part of the model. Keys are normally passed at __init__ time to initialise parameters, or at __call__ time to provide the randomness for stochastic layers (e.g. dropout). Neither of those use cases requires saving the key on the model.

I construct an attribute self.key. Perhaps that might be the issue?

patrick-kidger commented 1 year ago

in what case would a user need to "optimize" the key

I don't think it's a well-defined notion. :D

I construct an attribute self.key. Perhaps that might be the issue?

Why do you have this attribute at all? It doesn't look like it's being used anywhere.

neel04 commented 1 year ago

eqx.nn.Dropout requires a key to be passed to __call__, which means lots of ugly changes all down each module's __call__, plus ignoring axes with vmap, etc.

So setting it as an attribute re-uses the key, and thus doesn't require it to be passed to __call__, as we just access the attribute.

I take it there's probably a neater way to do that?

patrick-kidger commented 1 year ago

That's probably not what you want. It means you'll get the exact same dropout mask every time. In JAX, all randomness is a deterministic function of the PRNG key you provide.
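
The usual fix is to split off a fresh subkey for every call, so each step sees different randomness:

import equinox as eqx
import jax
import jax.numpy as jnp

dropout = eqx.nn.Dropout(p=0.5)
x = jnp.ones(8)

key = jax.random.PRNGKey(0)
for step in range(3):
    key, subkey = jax.random.split(key)  # fresh subkey => fresh dropout mask
    out = dropout(x, key=subkey)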

neel04 commented 1 year ago

Cool, then I'll just pre-generate an array of PRNGKeys every epoch and then vmap across that axis.
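
Something like this, I imagine (the sizes are made up):

import equinox as eqx
import jax
import jax.numpy as jnp

batch = jnp.ones((32, 8))
keys = jax.random.split(jax.random.PRNGKey(0), 32)  # one key per batch element

dropout = eqx.nn.Dropout(p=0.5)
out = jax.vmap(lambda x, k: dropout(x, key=k))(batch, keys)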

Thanks for all your help! Have a great day šŸš€