Open dionhaefner opened 4 weeks ago
This seems to be an Orbax issue rather than a Flax one. It looks like a recent change to Orbax assumes that each checkpoint entry with a '.zarray' file also has at least one other entry alongside it. The relevant function is _validate_params.
For instance:
import jax, jax.numpy as jnp
from flax.training import checkpoints
import tempfile

with tempfile.TemporaryDirectory() as dir_path:
    test_object = {
        'a': jnp.array([1, 2, 3], jnp.int32),
        'z': jnp.zeros((0,)),
    }
    file_path = checkpoints.save_checkpoint(
        dir_path, target=test_object, step=0, prefix='test_', keep=1
    )
    restored_object = checkpoints.restore_checkpoint(
        file_path, target=None
    )
    print(restored_object)
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
...
This produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but not 'z/0', since there is no data in the z tensor.
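To make the failure mode concrete, here is a minimal sketch of a check of this kind in plain Python. It is not Orbax's actual _validate_params implementation: the function name find_missing_params and the flat list-of-keys input are assumptions made purely for illustration. It flags a parameter as missing when it has a '.zarray' metadata entry but no other entry, which is the situation a zero-size array ends up in.

```python
from collections import defaultdict


def find_missing_params(entries):
    """Return params that have a '.zarray' entry but no other entry.

    Hypothetical stand-in for the kind of check described above; this is
    not Orbax's real code. `entries` is a flat list of keys such as
    'a/.zarray' or 'a/0'.
    """
    has_metadata = set()
    chunk_count = defaultdict(int)
    for key in entries:
        # Split 'a/0' into the param name 'a' and the entry name '0'.
        param, _, name = key.rpartition('/')
        if name == '.zarray':
            has_metadata.add(param)
        else:
            chunk_count[param] += 1
    # A param with metadata but zero chunk entries is reported as missing.
    return sorted(p for p in has_metadata if chunk_count[p] == 0)


# Entries from the example above: 'z' has metadata but no chunk.
entries = ['a/0', 'a/.zarray', 'z/.zarray']
print(find_missing_params(entries))  # ['z']
```

Under this assumption, a check that also looked at the shape recorded in '.zarray' could skip parameters with a zero-length dimension, since those legitimately have no chunks to write.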
Sooo should I take this up with the orbax people or are you already in contact?
Hi @IvyZX do you mind if I take a look and try to solve this on the Orbax side?
EDIT: @dionhaefner I opened Orbax issue 1309 for this. It's a bug in either Orbax or tensorstore, not Flax.
Trying to save a checkpoint when there are zero-size variables present raises an exception. Used to work fine pre-orbax. (This is part of a bigger model that has conditional logic where some of the variables are unused in certain configurations.)
Reproducer:
This prints:
System information