allenai / unified-io-2

Apache License 2.0

Error when loading large-3m #11

Closed: easonyang1996 closed this issue 8 months ago

easonyang1996 commented 8 months ago

In demo.ipynb, I load the model and parameters with parameters, param_axes = uio_utils.get_parameters(model, './large-3m/', partitioner) and get:

ValueError: Error opening "zarr" driver: Error reading local file "./large-3m/state.param_states.audio_token_embedder.embedding.m/.zarray": Invalid key: "./large-3m/state.param_states.audio_token_embedder.embedding.m/.zarray" [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{},\"file_io_sync\":true},\"driver\":\"zarr\",\"dtype\":\"float32\",\"kvstore\":{\"driver\":\"file\",\"path\":\"./large-3m/state.param_states.audio_token_embedder.embedding.m/\"},\"metadata\":{\"shape\":[8320,1024]}}'] [source locations='tensorstore/kvstore/file/file_key_value_store.cc:662\ntensorstore/kvstore/kvstore.cc:378\ntensorstore/driver/driver.cc:114']

jiasenlu commented 8 months ago

Thanks for reporting. @chrisc36, @sangho-vision and @zcczhang, can you help check this? I will look into it this weekend if there is no response.

zcczhang commented 8 months ago

Hi @easonyang1996, thanks for your interest! I have tried loading the large-3m checkpoint on both GPU and TPU machines, and it worked on my end. Could you please double-check that the model weights were downloaded with

aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"  

(especially the last argument, --exclude "state*")? After that, all sub-directories should start with target.
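The layout check described above can be scripted. This is a minimal sketch using only pathlib; the function names are hypothetical and not part of unified-io-2, and it assumes (as the tensorstore paths in this thread suggest) that each parameter is stored as its own zarr directory containing a .zarray file.

```python
from pathlib import Path


def non_target_dirs(ckpt_dir):
    """List sub-directories of ckpt_dir whose names do not start with "target".

    After downloading with --exclude "state*", this list should be empty.
    """
    root = Path(ckpt_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {root}")
    return sorted(
        p.name for p in root.iterdir()
        if p.is_dir() and not p.name.startswith("target")
    )


def dirs_missing_zarray(ckpt_dir):
    """List "target*" sub-directories that lack a .zarray metadata file.

    Assumption: one zarr directory per parameter, e.g.
    large-3m/target.audio_token_embedder.embedding/.zarray
    """
    root = Path(ckpt_dir)
    return sorted(
        p.name for p in root.iterdir()
        if p.is_dir()
        and p.name.startswith("target")
        and not (p / ".zarray").is_file()
    )
```

For example, non_target_dirs("large-3m") and dirs_missing_zarray("large-3m") should both return an empty list for a complete download. In this thread the layout turned out to be fine and loading still failed, so a clean result does not rule out other causes.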

easonyang1996 commented 8 months ago

Thanks for the reply @jiasenlu @zcczhang! I re-downloaded the checkpoint with aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*" and double-checked that all sub-directories in large-3m/ start with "target". However, I still get an error. The code is as follows:

>>> from t5x.examples.unified_io import utils as uio_utils
>>> from t5x import partitioning
>>> model = uio_utils.get_model('large', dtype="float32")
>>> partitioner = partitioning.PjitPartitioner(num_partitions=1)
>>> parameters, param_axes = uio_utils.get_parameters(model, './large-3m', partitioner)
WARNING:tensorflow:From /home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:absl:T5 library uses PAD_ID=0, which is different from the sentencepiece vocabulary, which defines pad_id=-1
/home/yangpengshuai/unified-io-2-main/t5x/examples/unified_io/models.py:131: UserWarning: Explicitly requested dtype 0.0 requested in ones is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  {k: jnp.ones(input_shapes[k], input_types[k]) for k in input_shapes},
/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yangpengshuai/unified-io-2-main/t5x/examples/unified_io/utils.py", line 107, in get_parameters
    ).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params
  File "/home/yangpengshuai/unified-io-2-main/t5x/utils.py", line 396, in restore
    self._checkpointer.restore(
  File "/home/yangpengshuai/unified-io-2-main/t5x/utils.py", line 290, in restore
    return self._restore_checkpointer.restore(
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1015, in restore
    state_dict = self._read_state_from_tensorstore(
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1126, in _read_state_from_tensorstore
    state_dict = _run_future_tree(future_state_dict)
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 165, in _run_future_tree
    leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
  File "/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoint_importer_vqgan.py", line 116, in _get_and_cast
    arr = await self._get_fn()  # pytype: disable=bad-return-type
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1081, in get_fn
    arr = await _read_ts(
  File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1584, in _read_ts
    arr = await gda_serialization.async_deserialize(
  File "/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/jax/experimental/gda_serialization/serialization.py", line 241, in async_deserialize
    t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
ValueError: Error opening "zarr" driver: Error reading local file "./large-3m/target.audio_token_embedder.embedding/.zarray": Invalid key: "./large-3m/target.audio_token_embedder.embedding/.zarray" [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"file_io_sync\":true},\"driver\":\"zarr\",\"dtype\":\"float32\",\"kvstore\":{\"driver\":\"file\",\"path\":\"./large-3m/target.audio_token_embedder.embedding/\"},\"metadata\":{\"shape\":[8320,1024]}}'] [source locations='tensorstore/kvstore/file/file_key_value_store.cc:662\ntensorstore/kvstore/kvstore.cc:378\ntensorstore/driver/driver.cc:114']

I don't know what causes the error when reading "./large-3m/target.audio_token_embedder.embedding/.zarray"

zcczhang commented 8 months ago

Ah, it seems this is caused by the leading ./ in the path: I could load with large-3m but not with ./large-3m. Could you first try removing it, or just use an absolute path? I'll look into a more flexible solution in the backend later.
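A minimal sketch of the absolute-path workaround suggested above, using only os.path from the standard library. The helper name normalize_ckpt_path is hypothetical, not a unified-io-2 function; with the later fix in the repo it may no longer be needed.

```python
import os


def normalize_ckpt_path(path):
    """Convert a relative checkpoint path such as './large-3m/' to an absolute one.

    expanduser handles a leading '~'; abspath resolves './' against the current
    working directory and normalizes away the trailing slash.
    """
    return os.path.abspath(os.path.expanduser(path))


# Hypothetical usage with the call from this thread:
# parameters, param_axes = uio_utils.get_parameters(
#     model, normalize_ckpt_path('./large-3m/'), partitioner)
```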

zcczhang commented 8 months ago

I've added a fix for this, and after several quick tests, FULL_CKPT_PATH in the demo can now be specified flexibly. Feel free to reopen the issue if the fix still does not work!

easonyang1996 commented 8 months ago

Yes, you are right! Thank you!