Open ybangaru opened 1 week ago
You need to pass restore_args when restoring with a different sharding than the one the checkpoint was saved with. I'd recommend following this documentation: https://orbax.readthedocs.io/en/latest/guides/checkpoint/checkpointing_pytrees.html
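A minimal sketch along the lines of that guide (the path and `template_pytree` below are placeholders, not from this thread; the template only needs the right structure, shapes and dtypes, and its shardings/placement are what decide where the restored arrays end up):

```python
import jax
import orbax.checkpoint as ocp

ckptr = ocp.StandardCheckpointer()

# `template_pytree` is a hypothetical pytree with the same structure as the
# checkpoint, whose leaves already live on the device(s) you want to restore
# onto (e.g. CPU). Its shardings, not the ones recorded at save time,
# determine where the restored arrays are placed.
abstract_tree = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, template_pytree
)

restored = ckptr.restore(
    "/path/to/checkpoint/step",  # placeholder path
    args=ocp.args.StandardRestore(abstract_tree),
)
```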
I managed to recover the model states and batch norm layer statistics by just changing orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer() to orbax_checkpointer_gen = orbax.checkpoint.StandardCheckpointer(). However, I also keep running normalization metrics for the different channels of my 3D arrays, in the form shown in the attached image: basically a dictionary with integer keys and dict values. Unfortunately, I'm not able to recover these by making the change above. Do you have any thoughts on how I might recover these values on a different device, i.e. the CPU?
norm_data_ckpt = {"data_norm_states": self.data_handler.norm_state_and_config["curr_scaler_state"]}
save_args_data_norm = orbax_utils.save_args_from_target(norm_data_ckpt)
self.data_norm_checkpoint_manager.save(
iter_value, norm_data_ckpt, save_kwargs={"save_args": save_args_data_norm}
)
latest_step = self.data_norm_checkpoint_manager.latest_step()
norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)
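A rough sketch of how the restore call above might be adapted so every leaf comes back as a host numpy array instead of being placed on the saved cuda:0 device; this assumes restore_kwargs mirrors the save_kwargs form used above and that item_metadata returns the saved tree structure (untested, and exact signatures vary between Orbax versions):

```python
import numpy as np
import jax
import orbax.checkpoint as ocp

latest_step = self.data_norm_checkpoint_manager.latest_step()

# Structure of what was saved at that step, used only to build a matching
# restore_args tree.
metadata = self.data_norm_checkpoint_manager.item_metadata(latest_step)

# Ask Orbax to hand every leaf back as a plain numpy array in host memory,
# so no placement onto the missing cuda:0 device is attempted.
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.RestoreArgs(restore_type=np.ndarray), metadata
)

norm_data_ckpt = self.data_norm_checkpoint_manager.restore(
    latest_step, restore_kwargs={"restore_args": restore_args}
)
```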
The error is as follows:

```
File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 95, in __init__
self.load_checkpoints(checkpoint_flag)
File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 285, in load_checkpoints
latest_state = self._load_data_norm_from_checkpoints(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 331, in _load_data_norm_from_checkpoints
norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1356, in restore
restored = self._checkpointer.restore(restore_directory, args=args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/async_checkpointer.py", line 429, in restore
return super().restore(directory, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
restored = self._handler.restore(directory, args=ckpt_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 811, in restore
restored[item_name] = handler.restore(
^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py", line 220, in restore
return self._impl.restore(
^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 769, in restore
return self._handler_impl.restore(directory, args=args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 694, in restore
tree_memory_size, restored_item = asyncio_utils.run_sync(
^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
return asyncio.run(coro)
^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 194, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 551, in _maybe_deserialize
deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py", line 1382, in deserialize
sharding = arg.sharding.to_jax_sharding()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/metadata/sharding.py", line 290, in to_jax_sharding
raise ValueError(
ValueError: Device cuda:0 was not found in jax.local_devices().
```
From the documentation that was shared, I could change both saving and restoring to make this work across devices, but it's really important for me to be able to recover the existing training runs on a different device. Thank you.
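One possible direction for the runs that are already saved, sketched under the assumption that ArrayRestoreArgs can override the sharding recorded in the checkpoint with a local CPU sharding (again untested, and signatures may differ between Orbax versions):

```python
import jax
from jax.sharding import SingleDeviceSharding
import orbax.checkpoint as ocp

cpu = jax.local_devices(backend="cpu")[0]

latest_step = self.data_norm_checkpoint_manager.latest_step()
metadata = self.data_norm_checkpoint_manager.item_metadata(latest_step)

# Replace the cuda:0 sharding stored in the checkpoint metadata with a
# single-device CPU sharding for every array leaf.
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.ArrayRestoreArgs(sharding=SingleDeviceSharding(cpu)), metadata
)

restored = self.data_norm_checkpoint_manager.restore(
    latest_step, restore_kwargs={"restore_args": restore_args}
)
```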
The following code blocks are being utilized to save the train state of the model during training and to restore the state back into memory. Version being used:
The training of the model was done on a GPU, and loading the state onto a GPU device works fine; however, when trying to load the model onto the CPU, the following error is thrown by the Orbax checkpoint manager's restore method.
Please let me know if you'd like any other details. I also added most of the traceback; it's a mess, but I hope it helps.