Closed easonyang1996 closed 8 months ago
Thanks for reporting. @chrisc36 @sangho-vision and @zcczhang, can you help check this? I will look into it this weekend if there is no response.
Hi @easonyang1996, thanks for the interest! I have tried loading the large-3m checkpoint on both GPU and TPU machines, and it worked on my end. Could you please double-check that the model weights were downloaded with
aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"
paying particular attention to the last argument, --exclude "state*"? After that, all sub-directories are supposed to start with target.
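For anyone who wants to automate the check above, here is a minimal sketch. It assumes the checkpoint was downloaded to a local directory; the helper name unexpected_subdirs is hypothetical and not part of the repository.

```python
import os


def unexpected_subdirs(ckpt_dir, prefix="target"):
    """Return the sorted names of sub-directories of ckpt_dir that do NOT start with prefix.

    Hypothetical helper for illustration only; it is not part of the unified-io-2 code.
    Plain files inside ckpt_dir are ignored, only directories are inspected.
    """
    return sorted(
        entry.name
        for entry in os.scandir(ckpt_dir)
        if entry.is_dir() and not entry.name.startswith(prefix)
    )


if __name__ == "__main__":
    ckpt_dir = "large-3m"  # assumed local download location from the aws command above
    if os.path.isdir(ckpt_dir):
        # An empty list means every sub-directory starts with "target".
        print(unexpected_subdirs(ckpt_dir))
```

An empty result is what the download with --exclude "state*" should produce; any names listed (for example ones starting with state) would be worth a second look.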
Thanks for the reply @jiasenlu @zcczhang! I have re-downloaded the checkpoint with aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"
and double-checked that all sub-directories in large-3m/ start with "target".
However, I still get an error. The code is as follows:
>>> from t5x.examples.unified_io import utils as uio_utils
>>> from t5x import partitioning
>>> model = uio_utils.get_model('large', dtype="float32")
>>> partitioner = partitioning.PjitPartitioner(num_partitions=1)
>>> parameters, param_axes = uio_utils.get_parameters(model, './large-3m', partitioner)
WARNING:tensorflow:From /home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:absl:T5 library uses PAD_ID=0, which is different from the sentencepiece vocabulary, which defines pad_id=-1
/home/yangpengshuai/unified-io-2-main/t5x/examples/unified_io/models.py:131: UserWarning: Explicitly requested dtype 0.0 requested in ones is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
{k: jnp.ones(input_shapes[k], input_types[k]) for k in input_shapes},
/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
warnings.warn(
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/yangpengshuai/unified-io-2-main/t5x/examples/unified_io/utils.py", line 107, in get_parameters
).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params
File "/home/yangpengshuai/unified-io-2-main/t5x/utils.py", line 396, in restore
self._checkpointer.restore(
File "/home/yangpengshuai/unified-io-2-main/t5x/utils.py", line 290, in restore
return self._restore_checkpointer.restore(
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1015, in restore
state_dict = self._read_state_from_tensorstore(
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1126, in _read_state_from_tensorstore
state_dict = _run_future_tree(future_state_dict)
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 165, in _run_future_tree
leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
File "/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoint_importer_vqgan.py", line 116, in _get_and_cast
arr = await self._get_fn() # pytype: disable=bad-return-type
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1081, in get_fn
arr = await _read_ts(
File "/home/yangpengshuai/unified-io-2-main/t5x/checkpoints.py", line 1584, in _read_ts
arr = await gda_serialization.async_deserialize(
File "/home/yangpengshuai/conda_envs/UnifiedIO2/lib/python3.10/site-packages/jax/experimental/gda_serialization/serialization.py", line 241, in async_deserialize
t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
ValueError: Error opening "zarr" driver: Error reading local file "./large-3m/target.audio_token_embedder.embedding/.zarray": Invalid key: "./large-3m/target.audio_token_embedder.embedding/.zarray" [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"file_io_sync\":true},\"driver\":\"zarr\",\"dtype\":\"float32\",\"kvstore\":{\"driver\":\"file\",\"path\":\"./large-3m/target.audio_token_embedder.embedding/\"},\"metadata\":{\"shape\":[8320,1024]}}'] [source locations='tensorstore/kvstore/file/file_key_value_store.cc:662\ntensorstore/kvstore/kvstore.cc:378\ntensorstore/driver/driver.cc:114']
I don't know what causes the error when reading "./large-3m/target.audio_token_embedder.embedding/.zarray"
Ah, it seems to be caused by the leading ./ in the path: I could load with large-3m but not with ./large-3m. Could you first try removing it, or just use the absolute path? I'll look into a more flexible solution in the backend later.
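As a possible stopgap on the user side, the path can be normalized before it is handed to the loader. This is only a sketch of the "use the absolute path" suggestion; normalize_ckpt_path is a hypothetical name, and it is not necessarily how the backend handles paths.

```python
import os


def normalize_ckpt_path(path):
    """Return an absolute, normalized version of a checkpoint path.

    Hypothetical helper for illustration only. os.path.abspath resolves the
    path against the current working directory and collapses a leading "./"
    as well as a trailing "/"; os.path.expanduser additionally expands "~".
    """
    return os.path.abspath(os.path.expanduser(path))


# Assumed usage with the call from the report above (not executed here):
# parameters, param_axes = uio_utils.get_parameters(
#     model, normalize_ckpt_path('./large-3m'), partitioner)
```

With this, './large-3m', './large-3m/' and 'large-3m' all resolve to the same absolute directory, so the loader never sees the relative prefix.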
I have added a fix for this. After several quick tests, the FULL_CKPT_PATH in the demo can now be specified flexibly. Feel free to reopen the issue if the fix still does not work!
Yes, you are right! Thank you!
In demo.ipynb, at the step that loads the model and parameters, the call
parameters, param_axes = uio_utils.get_parameters(model, './large-3m/', partitioner)
fails with: ValueError: Error opening "zarr" driver: Error reading local file "./large-3m/state.param_states.audio_token_embedder.embedding.m/.zarray": Invalid key: "./large-3m/state.param_states.audio_token_embedder.embedding.m/.zarray" [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{},\"file_io_sync\":true},\"driver\":\"zarr\",\"dtype\":\"float32\",\"kvstore\":{\"driver\":\"file\",\"path\":\"./large-3m/state.param_states.audio_token_embedder.embedding.m/\"},\"metadata\":{\"shape\":[8320,1024]}}'] [source locations='tensorstore/kvstore/file/file_key_value_store.cc:662\ntensorstore/kvstore/kvstore.cc:378\ntensorstore/driver/driver.cc:114']