google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

TypeError: write(): incompatible function arguments. #1039

Open FranzKnut opened 2 months ago

FranzKnut commented 2 months ago

After upgrading to jax==0.4.31, I get the error below when saving a model with the PyTreeCheckpointer. Downgrading to jax==0.4.30 works around it for now.

  File ".../site-packages/orbax/checkpoint/checkpointer.py", line 151, in save
    self._handler.save(tmpdir, args=ckpt_args)
  File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 500, in save
    super().save(directory, args=args)
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 615, in save
    asyncio.run(async_save(directory, *args, **kwargs))
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 608, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 489, in async_save
    return await super().async_save(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 568, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/type_handlers.py", line 1350, in serialize
    await asyncio.gather(*synchronous_ops)
  File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 304, in async_serialize
    return await asyncio.gather(*future_write_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 284, in _write_array
    write_future = t[shard.index].write(
                   ^^^^^^^^^^^^^^^^^^^^^
TypeError: write(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorstore.TensorStore, source: Union[tensorstore.TensorStore, numpy.typing.ArrayLike]) -> tensorstore.WriteFutures

Invoked with: TensorStore({
  'base': {
    'assume_metadata': True,
    'driver': 'zarr',
    'dtype': 'float32',
    'kvstore': {
      'base': {
        'driver': 'file',
        'path': '.../7afc2a53292a409e98a1a41b.orbax-checkpoint-tmp-0/ocdbt.process_0/',
      },
      'cache_pool': 'cache_pool#ocdbt',
      'config': {
        'compression': {'id': 'zstd'},
        'max_decoded_node_bytes': 100000000,
        'max_inline_value_bytes': 1024,
        'uuid': 'da97f553e0a14b1d4ee7c13b8c741510',
        'version_tree_arity_log2': 4,
      },
      'driver': 'ocdbt',
      'experimental_read_coalescing_interval': '1ms',
      'experimental_read_coalescing_merged_bytes': 500000000000,
      'experimental_read_coalescing_threshold_bytes': 1000000,
      'path': '0.batch_stats.encoder.BatchNorm_0.mean/',
    },
    'metadata': {
      'chunks': [8],
      'compressor': {'id': 'zstd', 'level': 1},
      'dimension_separator': '.',
      'dtype': '<f4',
      'fill_value': None,
      'filters': None,
      'order': 'C',
      'shape': [8],
      'zarr_format': 2,
    },
    'recheck_cached_data': False,
    'recheck_cached_metadata': False,
  },
  'context': {
    'cache_pool': {},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'cast',
  'dtype': 'float32',
  'transform': {'input_exclusive_max': [[8]], 'input_inclusive_min': [0]},
}), array([-0.81411403, -0.17954189,  0.06563368, -0.30694914, -0.9737644 ,
        0.6138028 ,  0.8517958 ,  0.22747818], dtype=float32); kwargs: can_reference_source_data_indefinitely=True

cpgaffney1 commented 2 months ago

Could you try upgrading Orbax? At head, Orbax no longer depends on jax/experimental/array_serialization/serialization.py.
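
The traceback shows `write()` being called with `can_reference_source_data_indefinitely=True`, while the installed TensorStore only accepts `(self, source)`. That suggests a version mismatch between the packages involved. A minimal sketch for reporting what is installed, before and after upgrading, is below. The helper `installed_versions` is hypothetical, and the distribution names `jax`, `orbax-checkpoint`, and `tensorstore` are assumed to match how the packages were installed in your environment.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust if your environment differs.
    for name, version in installed_versions(
        ["jax", "orbax-checkpoint", "tensorstore"]
    ).items():
        print(f"{name}: {version or 'not installed'}")
```

Including this output in a report makes it easier to tell which combination of versions triggers the error.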