ssenan opened this issue 4 months ago.
I just encountered the same issue in a different project. Even the toy example from the official tutorial fails. I'm running the following code on a TPUv4-16.
```python
import numpy as np
import jax
import orbax
import orbax.checkpoint as ocp
from flax.training import orbax_utils
# from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, PyTree
# from etils import epath
jax.distributed.initialize()
ocp.multihost.utils.initialize_runtime_to_distributed_ids()
CKPT_DIR = '/home/yfan/orbax_test'
params = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
    create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
    CKPT_DIR,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
save_args = orbax_utils.save_args_from_target(params)
ckpt_mgr.save(5, params, save_kwargs={'save_args': save_args})
```
The error says:
```
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
Traceback (most recent call last):
  File "/home/yfan/maxtext/MaxText/test_orbax.py", line 61, in <module>
    ckpt_mgr = orbax.checkpoint.CheckpointManager(
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 617, in __init__
    self._checkpoints = self._load_checkpoint_infos()
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1268, in _load_checkpoint_infos
    steps = utils.checkpoint_steps(
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 798, in checkpoint_steps
    return _checkpoint_steps(checkpoint_dir)
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 783, in _checkpoint_steps
    step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 754, in checkpoint_steps_paths
    raise ValueError(f'Path {checkpoint_dir} does not exist.')
ValueError: Path /home/yfan/orbax_test does not exist.
```
It feels like something is wrong with Orbax multi-host saving? The same code works perfectly on a single host with TPUv4-8.
I think the issue is that Orbax assumes the root directory is on global storage. If you specify `/tmp/...`, it is local to each process. Directory creation is almost always handled by one "primary" process, while the others just wait for the directory to be created. This avoids duplicate requests to the filesystem, which can become a problem when there are many processes. We recently made a change here to de-dup some of the directory creation requests, which is possibly why you have only started seeing this recently. Try using a root directory that is global, e.g. on GCS.
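To make that failure mode concrete, here is a toy simulation in plain Python. It is not Orbax code: threads stand in for hosts, index 0 plays the "primary" that creates the checkpoint directory, and every other "host" only polls for it. The helper names `wait_for_dir` and `simulate_hosts` and the 1-second timeout are hypothetical. With one shared root (like a GCS bucket) every waiter finds the directory; with a private root per host (like a per-VM `/tmp` or home directory) only the primary ever sees it.

```python
import os
import tempfile
import threading
import time


def wait_for_dir(path, timeout_s=1.0, poll_s=0.05):
    """Poll until `path` is a directory or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.isdir(path):
            return True
        time.sleep(poll_s)
    return os.path.isdir(path)


def simulate_hosts(num_hosts, root_for_host):
    """Host 0 creates the directory; every other host only waits for it.

    `root_for_host(i)` returns the storage root that host i can see.
    Returns a list of booleans: did host i see the directory?
    """
    seen = [False] * num_hosts

    def host(i):
        ckpt_dir = os.path.join(root_for_host(i), "orbax_test")
        if i == 0:
            os.makedirs(ckpt_dir, exist_ok=True)  # only the primary creates it
            seen[i] = True
        else:
            seen[i] = wait_for_dir(ckpt_dir)  # non-primaries just wait

    threads = [threading.Thread(target=host, args=(i,)) for i in range(num_hosts)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


if __name__ == "__main__":
    # Shared storage: all hosts see the same root.
    with tempfile.TemporaryDirectory() as shared:
        print(simulate_hosts(4, lambda i: shared))  # [True, True, True, True]
    # Host-local storage: each host has its own root.
    local_roots = [tempfile.TemporaryDirectory() for _ in range(4)]
    print(simulate_hosts(4, lambda i: local_roots[i].name))  # [True, False, False, False]
    for d in local_roots:
        d.cleanup()
```

The waiting hosts never error on their own in this sketch; they simply time out, which is roughly analogous to the "does not exist" errors raised on the non-primary processes.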
Also, @YUE-FAN, why are you using `initialize_runtime_to_distributed_ids`, and how did you even find out about this function? This is experimental code, and I don't think your use case needs it.
Hi @cpgaffney1, thanks for the response! I changed the directory to a GCS bucket and now see a different error, but it is still related to a directory not being found. I believe I have set all the correct permissions, as I am able to write other files to the storage bucket with no issue. I also ran this test with orbax-checkpoint installed directly from GitHub rather than the PyPI release and got the same error.
```
Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 47, in <module>
    test_checkpointing()
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 34, in test_checkpointing
    with ocp.CheckpointManager(path, item_names=("state", "custom_data"), metadata=global_metadata) as mngr:
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 634, in __init__
    self._save_metadata(metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1331, in _save_metadata
    self._metadata_checkpointer.save(path, metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 626, in create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/metadata
```
I think you don't actually have the latest version of the code. See checkpointer.py: at head there is no reference to `utils.create_tmp_directory`, and it calls `tmpdir = self.create_temporary_path(directory)` instead. Try deleting and reinstalling, and see if things work then.
Also check `orbax.checkpoint.__version__`.
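If it is unclear which build the interpreter is actually importing (for example, because a later install step silently pulls the PyPI release back in), one quick check is to compare the installed distribution metadata with the module that really gets imported. A minimal standard-library sketch; the helper names are hypothetical, `orbax-checkpoint` is assumed to be the distribution name, and `__version__` is the attribute mentioned above (read with `getattr` in case it is missing).

```python
import importlib
import importlib.metadata


def installed_version(dist_name="orbax-checkpoint"):
    """Version string recorded by pip for `dist_name`, or None if not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def imported_version(module_name="orbax.checkpoint"):
    """(__version__, __file__) of the module the interpreter actually imports."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None, None
    return getattr(module, "__version__", None), getattr(module, "__file__", None)


if __name__ == "__main__":
    print("pip metadata:", installed_version())
    print("imported:", imported_version())
```

Printing `__file__` as well as the version makes it obvious when the import resolves to a stale copy, e.g. one under `~/.local/lib/python3.10/site-packages`.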
@cpgaffney1 I am using version 0.5.20 of orbax.checkpoint. What do we need to update: orbax, orbax.checkpoint, and orbax-export too?
@cpgaffney1 You're right, I was actually accidentally reinstalling the PyPI version when loading my package, but installing directly from GitHub does resolve the issue (along with syncing all hosts before saving). Is this fix included in the 0.5.21 release on PyPI from yesterday, or will it be included in the next release?
Otherwise, thanks for all your assistance and feel free to close this issue!
Thanks! Using GCS solves the problem perfectly :D
I was using MaxText, where `ocp.multihost.utils.initialize_runtime_to_distributed_ids()` was called here after `jax.distributed.initialize()`.
`ocp.multihost.utils.initialize_runtime_to_distributed_ids()` is necessary for a special experimental feature used by MaxText, but you don't need it in toy examples.
Hi @cpgaffney1, oddly I managed to get it working for my toy example, but after migrating back to my main library I'm seeing a similar error again.
```
Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 224, in main
    train(
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 195, in train
    checkpoint_manager.save(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1078, in save
    self._checkpointer.save(save_directory, args=args)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 193, in save
    tmpdir = self.create_temporary_path(directory)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 157, in create_temporary_path
    tmpdir.create()
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 441, in create
    return _create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 192, in _create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/105
```
I'm still seeing it successfully write a checkpoint to my google storage, so I wonder if this error is coming from one of the non-primary hosts? Let me know if I'm still overlooking something, but I did confirm, as per your last suggestion, that I am running the latest version of Orbax (0.5.22, installed directly from GitHub).
@niketkumar could you take a look at this? It's a bit weird, because `checkpoint_metadata_store` should only be called on the primary process, and only immediately after the tmp directory is created. This initial metadata write also seems to be synchronous (ideally it would not be, actually...), so I don't think the path is getting deleted before the write completes.
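The ordering described here (create a temporary directory, write the initial metadata into it synchronously, then finalize it) can be sketched in plain Python. This is a simplified illustration, not Orbax's implementation: the function names, the `.tmp` suffix, and the `metadata.json` file name are all hypothetical. It shows why an error like the one in the traceback is surprising: the existence check runs on the same process, right after that process created the directory.

```python
import json
import os
import tempfile


def write_metadata(checkpoint_path, metadata):
    """Write metadata into an existing directory; error text mirrors the traceback."""
    if not os.path.isdir(checkpoint_path):
        raise ValueError(f"Checkpoint path does not exist: {checkpoint_path}")
    with open(os.path.join(checkpoint_path, "metadata.json"), "w") as f:
        json.dump(metadata, f)


def save_step(root, step, metadata):
    """Create a tmp dir, write metadata synchronously, then rename it into place."""
    final_dir = os.path.join(root, str(step))
    tmp_dir = final_dir + ".tmp"
    os.makedirs(tmp_dir)                # 1. (primary) creates the tmp directory
    write_metadata(tmp_dir, metadata)   # 2. immediately writes metadata, blocking
    os.rename(tmp_dir, final_dir)       # 3. finalize (atomic on a local POSIX FS)
    return final_dir


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        print(save_step(root, 105, {"step": 105}))
        try:
            write_metadata(os.path.join(root, "missing"), {})
        except ValueError as e:
            print(e)
```

`os.rename` gives an atomic finalize only on a local POSIX filesystem; object stores such as GCS have no atomic directory rename, so a real checkpointing library needs a different finalization strategy there.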
Based on the above error stack, it is not likely that `checkpoint_metadata_store` was called from a non-primary host.
The `checkpoint_metadata_store` write is called right after the tmp dir creation, so it is highly unlikely that the directory was deleted in between.
(Looking at the stack, it seems that `CheckpointManager` initializes a `Checkpointer`, not an `AsyncCheckpointer`. For a `Checkpointer`, we only allow synchronous `checkpoint_metadata_store` writes.)
@ssenan You can check whether the current host is the primary with the `multihost.is_primary_host(...)` API. Can you please log/print it in your run and verify?
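One way to do that is to prefix every log line with the host's identity, so each traceback can be matched to a specific process. A minimal sketch, assuming the process index and primary index are passed in so it runs without TPUs: in a real multi-host JAX run the index would come from `jax.process_index()`, or you could call `multihost.is_primary_host(...)` directly as suggested above. The helpers `host_tag` and `make_host_logger` and the default primary index of 0 are hypothetical.

```python
import logging


def host_tag(process_index, primary_host=0):
    """Short label such as '[host 3/non-primary]'."""
    role = "primary" if process_index == primary_host else "non-primary"
    return f"[host {process_index}/{role}]"


def make_host_logger(process_index, primary_host=0):
    """Return a logger whose messages are prefixed with the host tag."""
    logger = logging.getLogger(f"ckpt.host{process_index}")
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.StreamHandler()
        tag = host_tag(process_index, primary_host)
        handler.setFormatter(logging.Formatter(f"{tag} %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.propagate = False  # don't also print via the root logger
    logger.setLevel(logging.INFO)
    return logger


if __name__ == "__main__":
    # Hypothetical: in a real run, pass process_index=jax.process_index().
    log = make_host_logger(process_index=2)
    log.info("about to call checkpoint_manager.save(...)")
```

With one such line per host right before `save`, the host that raises the `ValueError` can be identified directly from the combined logs.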
@ssenan I didn't get what you meant by:
> I'm still seeing it successfully write a checkpoint to my google storage

I am not sure how the run managed to write a checkpoint in spite of the above error. Can you please explain your scenario and observations in a bit more detail?
@niketkumar Sorry for the delay, I will check on this / elaborate further in a couple days.
Hi Everyone,
I've been trying to checkpoint training using Orbax in a project (linked here). When I test the code locally I'm able to checkpoint successfully, but when training on a TPU v4-32 VM I encounter an issue related to directories not being found.
I've put together a simpler example using code from the Orbax docs, which outputs a similar error.
which appears to succeed on the process index, but fail on the rest of the hosts.
Here is the error I see:
Finally, here's the command I normally use when installing all my dependencies on the VM:
Is this issue related to my sharding and the directory not being created on all of the hosts or something on the Orbax end? Any assistance is greatly appreciated!