google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Can't save checkpoint with orbax when using zero-size parameters #4309

Open dionhaefner opened 4 weeks ago

dionhaefner commented 4 weeks ago

Trying to save a checkpoint when zero-size variables are present raises an exception. This used to work fine pre-Orbax. (This is part of a bigger model with conditional logic, where some of the variables are unused in certain configurations.)

Reproducer:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training.checkpoints import save_checkpoint

class DummyModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        var = self.variable('batch_stats', 'var', lambda: jnp.zeros((0,)))
        return x

state = DummyModule().init(jax.random.PRNGKey(0), jnp.zeros((1,)))
save_checkpoint('/tmp/foo', state, 1)

This prints:

$ python flax_bug_repro.py
Traceback (most recent call last):
  File "/Users/dion/codes/supersede/flax_bug_repro.py", line 17, in <module>
    save_checkpoint('/tmp/foo', state, 1)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/flax/training/checkpoints.py", line 694, in save_checkpoint
    orbax_checkpointer.save(
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 998, in finalize
    self._handler_impl.finalize(directory)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 782, in finalize
    asyncio.run(
  File "/opt/homebrew/Cellar/python@3.10/3.10.13/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/homebrew/Cellar/python@3.10/3.10.13/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 657, in merge_ocdbt_per_process_files
    await _validate_params(ts_kv_store, use_zarr3=use_zarr3)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 574, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/1 params are missing in checkpoint:
batch_stats.var.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/foo/checkpoint_1.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '1dc83c4d929da7f10e13b4bd3f592ccd',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).

System information

Flax==0.10.0
orbax==0.1.9

hrbigelow commented 3 weeks ago

This looks like an Orbax issue rather than a Flax one. A recent change to Orbax appears to assume that every checkpoint entry with a '.zarray' key also has at least one data entry alongside it. The relevant function is _validate_params.

For instance:

import jax, jax.numpy as jnp
from flax.training import checkpoints
import tempfile

with tempfile.TemporaryDirectory() as dir_path:
  test_object = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
  }
  file_path = checkpoints.save_checkpoint(
    dir_path, target=test_object, step=0, prefix='test_', keep=1
  )
  restored_object = checkpoints.restore_checkpoint(
    file_path, target=None
  )

print(restored_object)

This raises:

ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
...

This produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but no 'z/0', since the z tensor contains no data.
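
To illustrate the key layout described above, here is a simplified pure-Python model. It is only a sketch and not Orbax's actual _validate_params: the function names zarr_style_keys and params_without_data are hypothetical, and it assumes a zarr-v2-like layout with one metadata key per array plus one key per chunk.

```python
import itertools
import math


def zarr_style_keys(name, shape, chunk_shape=None):
    """List the keys a zarr-v2-like layout would hold for one array (assumed model)."""
    if chunk_shape is None:
        # Assumption: one chunk spans the whole array; clamp to 1 to avoid a zero divisor.
        chunk_shape = tuple(max(dim, 1) for dim in shape)
    keys = [f"{name}/.zarray"]  # the metadata entry exists even for empty arrays
    # Number of chunks along each axis; an axis of length 0 needs 0 chunks.
    grid = [math.ceil(dim / chunk) for dim, chunk in zip(shape, chunk_shape)]
    for index in itertools.product(*(range(n) for n in grid)):
        # Chunk keys look like "0" or "1.2"; a scalar array gets the single key "0".
        keys.append(f"{name}/" + (".".join(map(str, index)) or "0"))
    return keys


def params_without_data(keys):
    """Return parameters that have a metadata key but no chunk key."""
    with_metadata = {k.rsplit("/", 1)[0] for k in keys if k.endswith("/.zarray")}
    with_data = {k.rsplit("/", 1)[0] for k in keys if not k.endswith("/.zarray")}
    return sorted(with_metadata - with_data)


keys = zarr_style_keys("a", (3,)) + zarr_style_keys("z", (0,))
print(keys)                       # ['a/.zarray', 'a/0', 'z/.zarray']
print(params_without_data(keys))  # ['z']
```

Under this model, any check that requires a data key for every metadata key will flag a zero-size array as missing, even though there is nothing to write for it.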

dionhaefner commented 1 week ago

Sooo should I take this up with the orbax people or are you already in contact?

hrbigelow commented 1 week ago

Hi @IvyZX do you mind if I take a look and try to solve this on the Orbax side?

EDIT: @dionhaefner I opened Orbax issue 1309 for this. It's a bug in either Orbax or tensorstore, not Flax.