google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

StandardSave fails with confusing error when given scalars or arrays #1288

Closed: garymm closed this issue 3 weeks ago

garymm commented 3 weeks ago
import tempfile

import jax.numpy as jnp
import orbax.checkpoint as ocp

checkpoints_dir = tempfile.mkdtemp()
checkpoint_manager = ocp.CheckpointManager(directory=checkpoints_dir)

a_scalar_array = jnp.array(1)  # also fails in the same way with Python scalars
save_args = ocp.args.StandardSave(a_scalar_array)
checkpoint_manager.save(0, args=save_args)
checkpoint_manager.wait_until_finished()

Fails with:

ValueError: Save failed: 2/2 params are missing .zarray in checkpoint:
.zarray 
0.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/var/folders/6j/w0kq6xj1055g461npm9dr8z80000gn/T/tmpjs142dws/0.orbax-checkpoint-tmp-0/default.orbax-checkpoint-tmp-3/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '6ebac203e5eaad376e3bd348262366ef',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).

The workaround I've found is to use ArraySave instead.

cpgaffney1 commented 3 weeks ago

This is working as intended (WAI). At the moment, single arrays are not handled by StandardSave or PyTreeSave, though one could certainly argue that they should be.

garymm commented 3 weeks ago

@cpgaffney1 would you accept a PR adding a better error message for this case?

cpgaffney1 commented 2 weeks ago

Sure, feel free. Appreciate it!
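For illustration, a clearer error could come from a check that runs before anything is written. The sketch below is hypothetical: `check_standard_save_item` is not an Orbax function, the duck-typed array test is an assumption, and a real fix would more likely rely on JAX's tree utilities inside the handler itself.

```python
from typing import Any


def check_standard_save_item(item: Any) -> None:
    """Hypothetical pre-save check; not part of the Orbax API."""
    is_python_scalar = isinstance(item, (bool, int, float, complex))
    # Duck-typed array test: NumPy and JAX arrays both expose shape and dtype.
    is_array_like = hasattr(item, "shape") and hasattr(item, "dtype")
    if is_python_scalar or is_array_like:
        raise ValueError(
            "StandardSave expects a PyTree container such as a dict or list, "
            f"but got a single {type(item).__name__}. Wrap it in a container "
            "or use ArraySave instead."
        )


# A dict passes the check silently.
check_standard_save_item({"step_count": 1})

# A bare Python scalar is rejected with an actionable message.
try:
    check_standard_save_item(1)
except ValueError as err:
    message = str(err)
    print(message)
```

Failing early like this would name both workarounds from the thread, instead of surfacing a missing `.zarray` file after the write has already been attempted.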