I use load_t5x_checkpoint to convert T5X checkpoints to Hugging Face (HF) checkpoints by passing it the T5X checkpoint path.
However, a change in the codebase now seems to cause a ValueError when I try to convert a new Orbax checkpoint. (I don't know how it differs from previous versions, only that there is an additional directory, state, before the checkpoint steps.)
Calling variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) results in:
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1925, in load_t5x_checkpoint
ckpt_optimizer_state = _get_optimizer_state_dict(
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1692, in _get_optimizer_state_dict
raise ValueError(
ValueError: Checkpoint versions earlier than 2 are not supported. Got version: 0
Adding remap=False to avoid the above error instead results in:
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path, remap=False)
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1979, in load_t5x_checkpoint
state_dict = jax.tree_util.tree_map(
File "/admin/home-lintangsutawika/miniconda3/envs/t5v2/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/admin/home-lintangsutawika/miniconda3/envs/t5v2/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1975, in _create_lazy_awaitable_array
return LazyAwaitableArray.from_tensor_store_spec_or_array(
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoint_importer.py", line 186, in from_tensor_store_spec_or_array
return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype)
File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoint_importer.py", line 171, in from_array
dtype = array.dtype
AttributeError: 'str' object has no attribute 'dtype'
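The second traceback suggests that at least one leaf of the restored tree is a plain str rather than an array. A minimal sketch of how such leaves could be located is below. It assumes the restored structure is a nested dict, and the example_state contents and the FakeArray class are hypothetical stand-ins, not the real checkpoint layout.

```python
from typing import Any, Iterator, Tuple


def find_non_array_leaves(
    tree: Any, path: Tuple[str, ...] = ()
) -> Iterator[Tuple[str, Any]]:
    """Yield (path, value) for every leaf that has no `dtype` attribute."""
    if isinstance(tree, dict):
        # Recurse into nested dicts, extending the key path as we go.
        for key, value in tree.items():
            yield from find_non_array_leaves(value, path + (str(key),))
    elif not hasattr(tree, "dtype"):
        # Anything without `dtype` would fail at `array.dtype`.
        yield "/".join(path), tree


class FakeArray:
    """Hypothetical stand-in for an array-like leaf."""

    dtype = "float32"


# Hypothetical nested state dict; the keys are illustrative only.
example_state = {
    "target": {"decoder": {"kernel": FakeArray()}},
    "extra": {"step": FakeArray(), "note": "a plain string leaf"},
}

offending = list(find_non_array_leaves(example_state))
for leaf_path, value in offending:
    print(f"{leaf_path}: {type(value).__name__} = {value!r}")
```

This only walks dicts. A tree that also nests lists or tuples would need those cases handled as well.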