google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

_validate_params fails on zero-sized arrays #1309

Open hrbigelow opened 2 weeks ago

hrbigelow commented 2 weeks ago

Hi,

@niketkumar @cpgaffney1,

cc @dionhaefner

The following script attempts to save a zero-sized array, but the save fails validation in _validate_params.

I believe the problem is that _validate_params expects to find, for every 'foo/.zarray' entry, a matching data entry 'foo/0'. However, this code produces the TensorStore entries 'a/0', 'a/.zarray', and 'z/.zarray', but no 'z/0', since the z tensor contains no data.
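To make the suspected failure mode concrete, here is a minimal, hypothetical re-creation of the check as described above. It is not Orbax's actual _validate_params code. It assumes each parameter 'foo' is stored as a Zarr v2 metadata key 'foo/.zarray' plus chunk keys such as 'foo/0' (with '.' as the dimension separator), and it flags any parameter that has metadata but no chunk keys.

```python
def find_missing_params(keys: list[str]) -> list[str]:
    """Hypothetical sketch: params with '.zarray' metadata but no chunk keys.

    Assumes Zarr v2 layout with '.' as the dimension separator, so chunk keys
    look like 'foo/0' or 'foo/0.1', and metadata keys start with '.'.
    """
    has_metadata = set()
    has_chunks = set()
    for key in keys:
        name, _, leaf = key.rpartition('/')
        if leaf == '.zarray':
            has_metadata.add(name)
        elif not leaf.startswith('.'):
            # Any non-metadata leaf is treated as a data chunk.
            has_chunks.add(name)
    return sorted(has_metadata - has_chunks)


# The entries reported in this issue: 'z' has metadata but no chunk.
print(find_missing_params(['a/0', 'a/.zarray', 'z/.zarray']))
```

Under these assumptions the zero-sized 'z' is reported as missing even though nothing went wrong during the write, which matches the "1/2 params are missing" error below.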

I'm actually not sure whether TensorStore is supposed to write a z/0 entry, or what the intended behavior should be.
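One way to reason about whether a 'z/0' entry should exist is to count the chunks in a regular Zarr v2 chunk grid. The sketch below is a simplified model, not TensorStore's implementation; the function name is hypothetical, it assumes '.' as the dimension separator, and it only covers arrays with at least one dimension. Each dimension has ceil(size / chunk) grid cells, so any zero-length dimension leaves the grid with no chunks at all.

```python
import itertools
import math


def chunk_keys(shape, chunks):
    """Hypothetical sketch: Zarr v2-style chunk keys for a regular grid.

    Assumes len(shape) >= 1, every chunk size >= 1, and '.' as the
    dimension separator. A zero-length dimension yields an empty grid.
    """
    grid = [math.ceil(s / c) for s, c in zip(shape, chunks)]
    cells = itertools.product(*(range(n) for n in grid))
    return ['.'.join(map(str, idx)) for idx in cells]


print(chunk_keys((3,), (3,)))  # one chunk for the 'a' array
print(chunk_keys((0,), (1,)))  # no chunks for the zero-sized 'z' array
```

If TensorStore follows this model, the absence of 'z/0' would be expected behavior, and the question becomes how the validation step should treat parameters with an empty chunk grid.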

Any insight would be greatly appreciated!

import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp

target = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
}

orbax_checkpointer = ocp.Checkpointer(
  ocp.PyTreeCheckpointHandler()
)

with tempfile.TemporaryDirectory() as ckpt_path:
  overwrite = True
  save_args = jtu.tree_map(lambda _: ocp.SaveArgs(), target)
  orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)

(jax_env) henry@henry-gs65:orbax$ python flax4309.py 
Traceback (most recent call last):
  File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
    self._handler_impl.finalize(directory)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
    asyncio_utils.run_sync(
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '3ef941407cca4f778414e9e92b15dedb',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).
cpgaffney1 commented 2 weeks ago

Thanks for spotting this. Handling of 0-sized arrays is not well defined, and we have no tests (internal or external) for it. We will clarify the intended behavior, add tests, resolve the validation issue, and get back to you.