google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Multi-host Checkpointing Error #999

Open ssenan opened 4 months ago

ssenan commented 4 months ago

Hi Everyone,

I've been trying to checkpoint training with Orbax in my project. When I test the code locally I'm able to checkpoint successfully, but when training on a TPU v4-32 VM I encounter an error about directories not being found.

I've put together a simpler example using code from the Orbax docs, which outputs a similar error.

import jax
import numpy as np
import orbax.checkpoint as ocp
from jax.experimental import mesh_utils

def test_checkpointing():
    jax.distributed.initialize()

    if jax.process_index() == 0:
        print("Number of devices: ", jax.device_count())
        print("Local devices: ", jax.local_device_count())

    devices = mesh_utils.create_device_mesh((jax.device_count(),))
    mesh = jax.sharding.Mesh(devices, ("data",))
    sharding = jax.NamedSharding(
        mesh,
        jax.sharding.PartitionSpec(),
    )

    create_sharded_array = lambda x: jax.device_put(x, sharding)
    state = {
        "a": np.arange(16),
        "b": np.ones(16),
    }
    state = jax.tree_util.tree_map(create_sharded_array, state)
    abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
    print(jax.tree_util.tree_map(lambda x: x.shape, state))

    path = ocp.test_utils.erase_and_create_empty("/tmp/checkpoint")

    global_metadata = {'global_property': 'foo'}
    with ocp.CheckpointManager(path, item_names=("state", "custom_data"), metadata=global_metadata) as mngr:
        mngr.save(
            0,
            args=ocp.args.Composite(
                state=ocp.args.PyTreeSave(state),
                custom_data=ocp.args.JsonSave({"lang": "en", "version": 1.2}),
            ),
        )

    print("Checkpoint saved!")

if __name__ == "__main__":
    test_checkpointing()

This appears to succeed on process index 0, but fails on the rest of the hosts.

Here is the error I see:

Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 50, in <module>
    test_checkpointing()
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 35, in test_checkpointing
    ckptr.save(0, state)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1110, in save
    self._checkpointer.save(save_directory, args=args)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py", line 328, in save
    commit_ops = asyncio.run(self._handler.async_save(tmpdir, args=ckpt_args))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 646, in run_until_complete
    return future.result()
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 358, in async_save
    path.mkdir(parents=False, exist_ok=True)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/gpath.py", line 205, in mkdir
    self._backend.mkdir(self._path_str, exist_ok=exist_ok, mode=mode)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/backend.py", line 180, in mkdir
    os.mkdir(path, mode=mode)
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/checkpoint/0.orbax-checkpoint-tmp-0/default'

Finally, here's the command I normally use to install all my dependencies on the VM:

pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html  flax optax pandas numpy scipy wandb tqdm orbax-checkpoint gcsfs

Is this issue related to my sharding and the directory not being created on all of the hosts or something on the Orbax end? Any assistance is greatly appreciated!

YUE-FAN commented 4 months ago

I just encountered the same issue in a different project. Even the toy example from the official tutorial fails. I'm running the following code on a TPUv4-16.

import numpy as np
import jax
import orbax
import orbax.checkpoint as ocp
from flax.training import orbax_utils
# from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, PyTree
# from etils import epath

jax.distributed.initialize()
ocp.multihost.utils.initialize_runtime_to_distributed_ids()

CKPT_DIR = '/home/yfan/orbax_test'
params = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]

mgr_options = orbax.checkpoint.CheckpointManagerOptions(
  create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
  CKPT_DIR,
  orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
save_args = orbax_utils.save_args_from_target(params)
ckpt_mgr.save(5, params, save_kwargs={'save_args': save_args})

The error says:

WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
Traceback (most recent call last):
  File "/home/yfan/maxtext/MaxText/test_orbax.py", line 61, in <module>
    ckpt_mgr = orbax.checkpoint.CheckpointManager(
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 617, in __init__
    self._checkpoints = self._load_checkpoint_infos()
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1268, in _load_checkpoint_infos
    steps = utils.checkpoint_steps(
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 798, in checkpoint_steps
    return _checkpoint_steps(checkpoint_dir)
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 783, in _checkpoint_steps
    step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
  File "/home/yfan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 754, in checkpoint_steps_paths
    raise ValueError(f'Path {checkpoint_dir} does not exist.')
ValueError: Path /home/yfan/orbax_test does not exist.

It feels like something is wrong with Orbax multi-host saving? The same code works perfectly on a single host with TPUv4-8.

cpgaffney1 commented 4 months ago

I think the issue is that Orbax is assuming the root directory is a global storage. If you specify /tmp/...., it is local to each process. Directory creation is almost always handled by one "primary" process, while the others just wait for the directory to be created. This avoids duplicate requests to the filesystem, which can become a problem if there are many processes. We recently made a change here to de-dup some of the directory creation requests, so that is possibly why you are seeing this recently. Try using a root directory that is global, e.g. on GCS.

Also, @YUE-FAN why are you using initialize_runtime_to_distributed_ids, and how did you even find out about this function? This is experimental code and I don't think your use case needs it.
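The failure mode described in this comment can be reproduced without TPUs. The following is a minimal sketch, not Orbax's actual implementation: each simulated host gets either its own root directory (like /tmp on separate VMs) or a shared one (like a GCS bucket), only the primary creates the step directory, and every host then creates a per-item subdirectory with `parents=False`, as in the traceback above. The function names `save_item_dir` and `simulate` are hypothetical, and running the hosts in order stands in for the wait that non-primary processes perform.

```python
import tempfile
from pathlib import Path


def save_item_dir(root: Path, is_primary: bool) -> str:
    """Mimic one host's part of a save (simplified, hypothetical helper)."""
    step_dir = root / "0.orbax-checkpoint-tmp-0"
    if is_primary:
        # Only the primary host creates the temporary step directory.
        step_dir.mkdir()
    try:
        # Every host creates the per-item directory and expects its parent to exist.
        (step_dir / "default").mkdir(parents=False, exist_ok=True)
        return "ok"
    except FileNotFoundError:
        return "FileNotFoundError"


def simulate(shared_root: bool, num_hosts: int = 4) -> list:
    """Run the hosts in order; host 0 is the primary."""
    with tempfile.TemporaryDirectory() as base:
        results = []
        for host in range(num_hosts):
            # A shared root models global storage; per-host roots model /tmp.
            root = Path(base) / ("shared" if shared_root else f"host{host}")
            root.mkdir(exist_ok=True)
            results.append(save_item_dir(root, is_primary=(host == 0)))
        return results


if __name__ == "__main__":
    print("host-local root:", simulate(shared_root=False))
    print("shared root:    ", simulate(shared_root=True))
```

With host-local roots only the primary succeeds, which matches the report that the save works on one process and fails on the others. With a shared root every host finds the directory the primary created.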

ssenan commented 4 months ago

Hi @cpgaffney1, thanks for the response! I changed the directory to a GCS bucket and now see a different error, though it is still related to a directory not being found. I believe I have set all the correct permissions, as I am able to write other files to the storage bucket with no issue. I also ran this test with orbax-checkpoint installed directly from GitHub rather than the PyPI release and got the same error.

Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 47, in <module>
    test_checkpointing()
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 34, in test_checkpointing
    with ocp.CheckpointManager(path, item_names=("state", "custom_data"), metadata=global_metadata) as mngr:
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 634, in __init__
    self._save_metadata(metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1331, in _save_metadata
    self._metadata_checkpointer.save(path, metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 626, in create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/metadata

cpgaffney1 commented 4 months ago

I think you don't actually have the latest version of the code. See checkpointer.py. At head there is no reference to utils.create_tmp_directory, and it calls tmpdir = self.create_temporary_path(directory) instead. Try deleting and reinstalling, and see if things work then?

Also check orbax.checkpoint.__version__.
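One way to confirm which release is actually installed on each host is to query the package metadata with the standard library. This is a small sketch under the assumption that the distributions are named as in the pip command and comments in this thread (`orbax-checkpoint`, `orbax`, `orbax-export`); `installed_version` is a hypothetical helper, and a version string alone may not tell a GitHub install apart from a PyPI release that carries the same number.

```python
from importlib import metadata


def installed_version(dist_name: str):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Distribution names taken from this thread; adjust to your environment.
    for name in ("orbax-checkpoint", "orbax", "orbax-export"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

Running this on every host of the TPU VM, not just the one you are logged into, helps catch the case where a later install step silently replaced the package.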

sumanttyagi commented 4 months ago

@cpgaffney1 I am using version 0.5.20 of orbax.checkpoint. What do we need to update? orbax, orbax.checkpoint, and orbax-export too?

ssenan commented 4 months ago

@cpgaffney1 You're right, I was accidentally reinstalling the PyPI version when loading my package. Installing directly from GitHub does resolve the issue (along with syncing all hosts before saving). Is this fix included in the 0.5.21 release on PyPI from yesterday, or will it be included in the next release?

Otherwise, thanks for all your assistance and feel free to close this issue!

YUE-FAN commented 4 months ago

> I think the issue is that Orbax is assuming the root directory is a global storage. If you specify /tmp/...., it is local to each process. Directory creation is almost always handled by one "primary" process, while the others just wait for the directory to be created. This avoids duplicate requests to the filesystem, which can become a problem if there are many processes. We recently made a change here to de-dup some of the directory creation requests, so that is possibly why you are seeing this recently. Try using a root directory that is global, e.g. on GCS.
>
> Also, @YUE-FAN why are you using initialize_runtime_to_distributed_ids, and how did you even find out about this function? This is experimental code and I don't think your use case needs it.

Thanks! Using GCS solves the problem perfectly :D

I was using MaxText, where ocp.multihost.utils.initialize_runtime_to_distributed_ids() is called right after jax.distributed.initialize().

cpgaffney1 commented 4 months ago

ocp.multihost.utils.initialize_runtime_to_distributed_ids() is necessary for a special experimental feature used by MaxText, but you don't need it in toy examples.

ssenan commented 4 months ago

Hi @cpgaffney1, oddly I managed to get it working for my toy example, but after migrating back to my main library I'm again seeing a similar error.

Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 224, in main
    train(
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 195, in train
    checkpoint_manager.save(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1078, in save
    self._checkpointer.save(save_directory, args=args)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 193, in save
    tmpdir = self.create_temporary_path(directory)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 157, in create_temporary_path
    tmpdir.create()
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 441, in create
    return _create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 192, in _create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/105

I'm still seeing it successfully write a checkpoint to my google storage, so I wonder if this is coming from one of the non-primary hosts? Let me know if I'm still overlooking something, but I did confirm, as per your last suggestion, that I am running the latest version of Orbax (0.5.22, installed directly from GitHub).

cpgaffney1 commented 3 months ago

@niketkumar could you take a look at this? It's a bit weird because checkpoint_metadata_store should only be called on the primary process, which would only happen immediately after creating the tmp directory. And this initial metadata write seems to be synchronous (ideally it would not be, actually...), so I don't think the path is getting deleted before the write completes.

niketkumar commented 3 months ago

Based on the above error stack, it is not likely that checkpoint_metadata_store was called from a non-primary host.

The checkpoint_metadata_store write is called right after the tmp dir creation, so it is highly unlikely that it was deleted.

(Looking at the stack, it seems that CheckpointManager initializes a Checkpointer, not AsyncCheckpointer. For a Checkpointer, we only allow synchronous checkpoint_metadata_store.)

@ssenan You can check whether the current host is the primary with the multihost.is_primary_host(...) API. Can you please log/print it in your run and verify this?

@ssenan I didn't get what you meant by

> I'm still seeing it successfully write a checkpoint to my google storage

I am not sure how the run managed to write a checkpoint in spite of the above error. Can you please explain your scenario and observations a bit?
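The logging requested above could look roughly like the sketch below. It deliberately avoids assuming the signature of `multihost.is_primary_host(...)`, which is not shown in this thread. Instead, the hypothetical helpers `host_role` and `log_before_save` take the process index as a plain argument (in a real run this would come from `jax.process_index()`) and treat process 0 as the primary, which is an assumption.

```python
import logging

logging.basicConfig(level=logging.INFO)


def host_role(process_index: int, primary_host: int = 0) -> str:
    """Classify a process as primary or non-primary (primary_host=0 is an assumption)."""
    return "primary" if process_index == primary_host else "non-primary"


def log_before_save(process_index: int, step: int, primary_host: int = 0) -> str:
    """Log and return a line that identifies which host is about to save."""
    role = host_role(process_index, primary_host)
    message = f"[process {process_index}] {role} host entering save for step {step}"
    logging.info(message)
    return message


if __name__ == "__main__":
    # In a multi-host job, pass jax.process_index() instead of these fixed values.
    for index in range(4):
        log_before_save(index, step=105)
```

Calling `log_before_save` immediately before `checkpoint_manager.save(...)` on every host makes it possible to match the traceback to a specific process and see whether the failing call ran on the primary.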

ssenan commented 3 months ago

@niketkumar Sorry for the delay, I will check on this / elaborate further in a couple days.