google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Restoring flax model checkpoints using orbax throws ValueError #1305

Open ybangaru opened 1 week ago

ybangaru commented 1 week ago

The following code blocks are used to save the model's train state during training and to restore that state back into memory. Version in use:

orbax-checkpoint                   0.8.0

from flax.training import orbax_utils
import orbax.checkpoint

# Legacy API: a PyTreeCheckpointer wrapped in a CheckpointManager that
# saves every 5 steps and creates the directory if it doesn't exist.
directory_gen_path = "checkpoints_loc"
orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer()
gen_options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=5, create=True)
gen_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    directory_gen_path, orbax_checkpointer_gen, gen_options
)

def save_model_checkpoints(step_, generator_state, generator_batch_stats):
    # Bundle the train state and the batch norm statistics into one pytree.
    gen_ckpt = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    save_args_gen = orbax_utils.save_args_from_target(gen_ckpt)
    gen_checkpoint_manager.save(step_, gen_ckpt, save_kwargs={"save_args": save_args_gen})

def load_model_checkpoints(generator_state, generator_batch_stats):
    # Template pytree matching the structure that was saved above.
    gen_target = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    latest_step = gen_checkpoint_manager.latest_step()
    gen_ckpt = gen_checkpoint_manager.restore(latest_step, items=gen_target)
    generator_state = gen_ckpt["model"]
    generator_batch_stats = gen_ckpt["batch_stats"]

    return generator_state, generator_batch_stats

The model was trained on a GPU, and restoring the state onto a GPU device works fine. However, when trying to load the model onto the CPU, the checkpoint manager's restore method throws the following error:

/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py:1386: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(
ERROR:root:Device cuda:0 was not found in jax.local_devices().
ERROR:root:Device cuda:0 was not found in jax.local_devices()
.......
......
......

  File "/user/yashbangaru/simgan/pysrc/package/model_handlers.py", line 453, in load_model_checkpoints
    gen_ckpt = self.gen_checkpoint_manager.restore(latest_step, items=gen_target)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1356, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 811, in restore
    restored[item_name] = handler.restore(
                          ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 769, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 694, in restore
    tree_memory_size, restored_item = asyncio_utils.run_sync(
                                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 551, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py", line 1442, in deserialize
    ret = await asyncio.gather(*deserialize_ops)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/serialization.py", line 591, in async_deserialize
    raise ValueError(
ValueError: sharding passed to deserialization should be specified, concrete and an instance of `jax.sharding.Sharding`. Got None

Please let me know if you'd like any other details. I've also included most of the traceback; it's a mess, but I hope it helps.

cpgaffney1 commented 3 days ago

You need to pass restore_args when restoring with a different sharding than the one the checkpoint was saved with. I'd recommend following this documentation: https://orbax.readthedocs.io/en/latest/guides/checkpoint/checkpointing_pytrees.html
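
With the legacy CheckpointManager API used above, a minimal sketch might look like the following (assuming a single-process setup and that the legacy restore_kwargs path still applies in your version; gen_target is the template dict from load_model_checkpoints):

import jax
import orbax.checkpoint as ocp

# Sketch: force every restored array onto the single CPU device instead of
# the cuda:0 sharding recorded in the checkpoint metadata.
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.ArrayRestoreArgs(sharding=cpu_sharding), gen_target
)
latest_step = gen_checkpoint_manager.latest_step()
gen_ckpt = gen_checkpoint_manager.restore(
    latest_step,
    items=gen_target,
    restore_kwargs={"restore_args": restore_args},
)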

ybangaru commented 18 hours ago

I managed to recover the model states and batch norm layer statistics just by changing orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer() to orbax_checkpointer_gen = orbax.checkpoint.StandardCheckpointer(). However, I also have continuous normalization metrics for the different channels of my 3D arrays, in the form shown in the image below: basically a dictionary with integer keys and dict values. Unfortunately, I'm not able to recover this by making the aforementioned change. Could you please share any thoughts on how I might recover these values on a different device, i.e. the CPU?

[Image: screenshot of the per-channel normalization-state dictionary]
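
(Roughly this shape; the field names here are hypothetical, since the screenshot doesn't reproduce:)

# Illustration only: integer channel indices mapping to per-channel
# scaler-state dicts.
curr_scaler_state = {
    0: {"mean": 0.12, "var": 1.03, "count": 51200},
    1: {"mean": -0.07, "var": 0.98, "count": 51200},
}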

# Save: bundle the scaler state and write it through the manager.
norm_data_ckpt = {"data_norm_states": self.data_handler.norm_state_and_config["curr_scaler_state"]}
save_args_data_norm = orbax_utils.save_args_from_target(norm_data_ckpt)
self.data_norm_checkpoint_manager.save(
    iter_value, norm_data_ckpt, save_kwargs={"save_args": save_args_data_norm}
)

# Restore: no items/target is passed, so sharding is read from the
# checkpoint metadata (which still points at cuda:0).
latest_step = self.data_norm_checkpoint_manager.latest_step()
norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)

The error is the following:

 File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 95, in __init__
    self.load_checkpoints(checkpoint_flag)
  File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 285, in load_checkpoints
    latest_state = self._load_data_norm_from_checkpoints(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 331, in _load_data_norm_from_checkpoints
    norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1356, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/async_checkpointer.py", line 429, in restore
    return super().restore(directory, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 811, in restore
    restored[item_name] = handler.restore(
                          ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py", line 220, in restore
    return self._impl.restore(
           ^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 769, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 694, in restore
    tree_memory_size, restored_item = asyncio_utils.run_sync(
                                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 551, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py", line 1382, in deserialize
    sharding = arg.sharding.to_jax_sharding()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/metadata/sharding.py", line 290, in to_jax_sharding
    raise ValueError(
ValueError: Device cuda:0 was not found in jax.local_devices().

From the documentation that was shared, I could change both saving and restoring to make this work across devices, but it's really important for me to be able to recover my existing training runs on a different device. Thank you.
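
In case it helps frame an answer: would restoring every leaf as a plain numpy array on the host work, so the cuda:0 sharding in the checkpoint metadata is never consulted? A sketch of what I mean, keeping the original PyTreeCheckpointer-backed manager (norm_data_template and its leaves are placeholders mirroring the saved structure):

import jax
import numpy as np
import orbax.checkpoint as ocp

# Sketch (untested): a host-side template with numpy leaves, restored via
# RestoreArgs(restore_type=np.ndarray) so no device sharding is required.
norm_data_template = {
    "data_norm_states": {
        0: {"mean": np.zeros(()), "var": np.zeros(()), "count": np.zeros(())},
        1: {"mean": np.zeros(()), "var": np.zeros(()), "count": np.zeros(())},
    }
}
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.RestoreArgs(restore_type=np.ndarray), norm_data_template
)
latest_step = self.data_norm_checkpoint_manager.latest_step()
norm_data_ckpt = self.data_norm_checkpoint_manager.restore(
    latest_step,
    items=norm_data_template,
    restore_kwargs={"restore_args": restore_args},
)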