google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Command that fails in python 3.8 #375

Closed sshleifer closed 1 year ago

sshleifer commented 1 year ago

I have a command that runs successfully on TPU and fails on CPU. The CPU machine has more RAM, so I don't think that's the issue. Both machines use the same orbax version (2.3.1).

If I don't set concurrent_gb=100, I get

ValueError: Requested more bytes than we reserved space for: 96636764160 > 96000000000

in both environments.
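The numbers in that message can be checked by hand. The sketch below assumes that concurrent_gb is converted to bytes by multiplying by 10**9; that is inferred only from the 96000000000 figure in the error, not from the library source. The helper name fits is hypothetical.

```python
# Numbers copied from the ValueError above.
requested_bytes = 96_636_764_160
reserved_bytes = 96_000_000_000

GB = 10**9    # decimal gigabyte
GIB = 2**30   # binary gibibyte

# The request is exactly 90 GiB, which is slightly more than 96 decimal GB.
assert requested_bytes == 90 * GIB
assert reserved_bytes == 96 * GB


def fits(concurrent_gb: int, needed: int = requested_bytes) -> bool:
    """Hypothetical check: does one request fit under the byte budget?

    Assumption: the budget is concurrent_gb * 10**9 bytes.
    """
    return needed < concurrent_gb * GB


print(fits(96))   # False: a 96 GB budget is too small for a 90 GiB request
print(fits(100))  # True: 100 GB leaves room
```

Under that assumption, any concurrent_gb of 97 or more would clear this particular request.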

So I set it to 100 and then get

RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop

on CPU but success on TPU.

Any idea?

Code it's hitting

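# Imports assumed from the names used below; the original snippet omitted them.
import jax
from orbax import checkpoint
from orbax.checkpoint import Checkpointer, CheckpointManager, CheckpointManagerOptions

# DataloaderHandler, TrainConfigHandler, map_to_pspec, logger, MS and PARAMS
# are project-specific and defined elsewhere.


def create_orbax_checkpoint_manager(p):
    # concurrent_gb=100 is needed to get past the ValueError shown above.
    h = checkpoint.PyTreeCheckpointHandler(concurrent_gb=100)
    mstate_ckptr = Checkpointer(h)  # also fails with AsyncCheckpointer
    mngr = CheckpointManager(
        p,
        checkpointers={
            "model_state": mstate_ckptr,
            "dataloader_state": Checkpointer(DataloaderHandler()),
            "config": Checkpointer(TrainConfigHandler()),
            "train_rng": Checkpointer(DataloaderHandler()),
        },
        options=CheckpointManagerOptions(create=True, cleanup_tmp_directories=True),
    )
    return mngr


# Excerpt from load_state_if_possible() in checkpointing.py.
restore_args = jax.tree_util.tree_map(
    map_to_pspec, unboxed_train_state, state_mesh_annotations
)
ckpt_manager = create_orbax_checkpoint_manager(load_parameters_path)
logger.info(f"restoring state from {load_parameters_path=}")
items = {MS: {PARAMS: unboxed_train_state.params}}
kw = {MS: {"restore_args": {PARAMS: restore_args.params}}}
params = ckpt_manager.restore(step=1, items=items, restore_kwargs=kw)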

Command

(On TPU I just remove the env vars.)

TPU_ACCELERATOR_TYPE='' JAX_PLATFORM='cpu' XLA_FLAGS="--xla_force_host_platform_device_count=4" CUDA_VISIBLE_DEVICES="" JAX_CACHE_DIR=$HOME/.jax_cache python \
  eval.py -m exp.run_name=108b_heather_B16_dringus model=108b \
  exp.load_parameters_path=/mnt/resource_nvme/heather.0623.scale108b.jax.bf16.noep \
  exp.checkpoint_dir=/mnt/resource_nvme/heather.0623.scale108b.jax.sharded.noep \
  "$@"

Traceback

Traceback (most recent call last):
  File "eval.py", line 279, in eval_main
    eval_loop(cfg)
  File "eval.py", line 166, in eval_loop
    (state, _, state_mesh_annotations, _,) = max_utils.setup_initial_state(
  File "/home/sam/character-tech/maxtext/MaxText/max_utils.py", line 424, in setup_initial_state
    ) = checkpointing.load_state_if_possible(
  File "/home/sam/character-tech/maxtext/MaxText/checkpointing.py", line 361, in load_state_if_possible
    params = ckpt_manager.restore(step=1, items=items, restore_kwargs=kw)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 565, in restore
    restored_items = self._restore_impl(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 597, in _restore_impl
    restored[item_name] = self._checkpointers[item_name].restore(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpointer.py", line 97, in restore
    restored = self._handler.restore(directory, *args, item=item, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 569, in restore
    restored_item = asyncio.run(_restore())
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
    return future.result()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 566, in _restore
    flat = await asyncio.gather(*flat)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 48, in maybe_get_async
    return await value.get_async(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 30, in get_async
    return await self._get_fn(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 448, in _deserialize
    return await handler.deserialize(info, args)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/type_handlers.py", line 541, in deserialize
    return await serialization.async_deserialize(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 287, in async_deserialize
    return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 55, in create_async_array_from_callback
    dbs = await asyncio.gather(*future_arrays)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 261, in cb
    await byte_limiter.wait_for_bytes(requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 130, in wait_for_bytes
    await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 400, in wait_for
    await self.wait()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 373, in wait
    await fut
RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop

sshleifer commented 1 year ago

Might be the Python version: 3.8.12 on the failing (CPU) machine vs. 3.11.3 on the succeeding (TPU) machine.
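A likely explanation, not confirmed in this thread: before Python 3.10, asyncio.Lock and asyncio.Condition captured the current event loop when they were constructed, while from 3.10 on they bind to the running loop on first use. If the byte limiter's condition variable is built before asyncio.run() starts its own fresh loop, then on 3.8 the futures it creates belong to a different loop. The sketch below reproduces the same RuntimeError on any Python version by awaiting a Future owned by another loop; the function name is hypothetical.

```python
import asyncio


def await_future_from_other_loop() -> str:
    """Await a Future owned by one loop from a task running on another loop."""
    other_loop = asyncio.new_event_loop()
    foreign_future = other_loop.create_future()  # bound to other_loop

    async def waiter():
        await foreign_future

    try:
        # asyncio.run() creates its own new loop, so the task and the
        # future end up on different loops.
        asyncio.run(waiter())
    except RuntimeError as err:
        return str(err)
    finally:
        other_loop.close()
    return "no error"


message = await_future_from_other_loop()
print(message)
```

The printed message ends with "attached to a different loop", matching the last line of the traceback above.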

sshleifer commented 1 year ago

Upgrading Python resolved it.
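For code that has to stay on Python 3.8 or 3.9, a general asyncio pattern is to construct loop-bound primitives inside the coroutine that asyncio.run() executes, so they attach to the loop that actually runs them. The sketch below is a minimal illustration of that pattern, not the JAX or Orbax implementation: InFlightByteLimiter, fake_read and restore_all are hypothetical names loosely modelled on the wait_for_bytes frame in the traceback.

```python
import asyncio


class InFlightByteLimiter:
    """Hypothetical stand-in for the byte limiter seen in the traceback."""

    def __init__(self, max_bytes: int):
        self._max_bytes = max_bytes
        self._available_bytes = max_bytes
        # On Python < 3.10 this binds to the loop that is current right now,
        # so the limiter must be created inside the running loop.
        self._cv = asyncio.Condition()

    async def wait_for_bytes(self, requested_bytes: int) -> None:
        if requested_bytes >= self._max_bytes:
            raise ValueError("Requested more bytes than we reserved space for")
        async with self._cv:
            await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
            self._available_bytes -= requested_bytes

    async def release_bytes(self, requested_bytes: int) -> None:
        async with self._cv:
            self._available_bytes += requested_bytes
            self._cv.notify_all()


async def fake_read(limiter, name, nbytes, log):
    """Pretend to read one array while holding part of the byte budget."""
    await limiter.wait_for_bytes(nbytes)
    try:
        await asyncio.sleep(0)  # stand-in for the actual read
        log.append(name)
    finally:
        await limiter.release_bytes(nbytes)


async def restore_all():
    # Built inside the coroutine, so it binds to the loop asyncio.run() created.
    limiter = InFlightByteLimiter(max_bytes=100)
    log = []
    await asyncio.gather(
        *(fake_read(limiter, f"array{i}", 60, log) for i in range(3))
    )
    return sorted(log)


restored = asyncio.run(restore_all())
print(restored)
```

Whether Orbax or JAX offers a way to defer the limiter's construction like this is not stated in the thread; upgrading Python is the fix that worked here.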