google-research / t5x

`load_t5x_checkpoint` does not work with new checkpoints #1524

Closed lintangsutawika closed 4 months ago

lintangsutawika commented 4 months ago

I use `load_t5x_checkpoint` to help convert t5x checkpoints to hf checkpoints: I pass the t5x checkpoint path to that function to load it.

However, a change in the codebase now seems to cause a ValueError when I try to convert a new Orbax checkpoint. (I don't know how it differs from previous versions, only that there is an additional directory of state before the checkpoint steps.)

Calling `variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)` results in:

    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1925, in load_t5x_checkpoint
    ckpt_optimizer_state = _get_optimizer_state_dict(
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1692, in _get_optimizer_state_dict
    raise ValueError(
ValueError: Checkpoint versions earlier than 2 are not supported. Got version: 0

Adding `remap=False` to avoid the error above results in:

    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path, remap=False)
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1979, in load_t5x_checkpoint
    state_dict = jax.tree_util.tree_map(
  File "/admin/home-lintangsutawika/miniconda3/envs/t5v2/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/admin/home-lintangsutawika/miniconda3/envs/t5v2/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoints.py", line 1975, in _create_lazy_awaitable_array
    return LazyAwaitableArray.from_tensor_store_spec_or_array(
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoint_importer.py", line 186, in from_tensor_store_spec_or_array
    return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype)
  File "/weka/lintangsutawika/01-t5v2/t5x/t5x/checkpoint_importer.py", line 171, in from_array
    dtype = array.dtype
AttributeError: 'str' object has no attribute 'dtype'
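
The second traceback suggests that at least one leaf of the restored tree is a plain string rather than an array, since `.dtype` is looked up on a `str`. A minimal sketch for locating such leaves in a nested dict follows. The helper name `find_non_array_leaves`, the key names, and the sample values are all hypothetical and not part of t5x; a real state dict would come from the checkpoint, not from this stand-in.

```python
from types import SimpleNamespace


def find_non_array_leaves(tree, path=()):
    """Return (path, type name) for every leaf that has no `dtype` attribute."""
    if isinstance(tree, dict):
        found = []
        for key, value in tree.items():
            found.extend(find_non_array_leaves(value, path + (str(key),)))
        return found
    if hasattr(tree, "dtype"):
        return []  # looks like an array, nothing to report
    return [("/".join(path), type(tree).__name__)]


# Hypothetical stand-in for a restored state dict.
# SimpleNamespace only mimics an array by carrying a `dtype` attribute.
state_dict = {
    "target": {"decoder": {"kernel": SimpleNamespace(dtype="float32")}},
    "state": {"step": "not-an-array"},  # made-up string leaf
}

offenders = find_non_array_leaves(state_dict)
print(offenders)  # [('state/step', 'str')]
```

Running a check like this on the tree before it is mapped could show which entries are strings, which may help confirm whether the new checkpoint layout is the cause.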
lintangsutawika commented 4 months ago

Never mind, I just used `--gin.train.use_orbax=False` to avoid this issue.