I have a command that runs successfully on TPU but fails on CPU. The CPU machine has more RAM, so I don't think memory is the issue. Both environments use the same orbax version (2.3.1).
If I don't set concurrent_gb=100, I get this error in both environments:
ValueError: Requested more bytes than we reserved space for: 96636764160 > 96000000000
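For reference, here is how the numbers in this error work out. This assumes the reservation is concurrent_gb × 10^9 bytes (decimal GB), which matches the 96000000000 in the message when concurrent_gb is left at what appears to be its default of 96. The requested size is exactly 90 GiB, about 96.64 decimal GB, so any concurrent_gb of 97 or more clears this particular check. `fits` and `min_concurrent_gb` are illustrative helpers, not orbax or JAX functions.

```python
# Hypothetical helpers reproducing the arithmetic behind the ValueError.
# Assumption: the byte limit is concurrent_gb * 10**9 (decimal gigabytes).
GB = 10**9    # decimal gigabyte
GiB = 2**30   # binary gibibyte

requested = 96_636_764_160   # bytes requested, copied from the error message
reserved_default = 96 * GB   # 96000000000, the reservation in the error message


def fits(concurrent_gb: int, requested_bytes: int) -> bool:
    """True if a single request stays under the reservation (no ValueError)."""
    return requested_bytes < concurrent_gb * GB


def min_concurrent_gb(requested_bytes: int) -> int:
    """Smallest integer concurrent_gb whose reservation exceeds the request."""
    return requested_bytes // GB + 1


print(requested / GiB)               # 90.0
print(fits(96, requested))           # False
print(fits(100, requested))          # True
print(min_concurrent_gb(requested))  # 97
print(100 * GB - requested)          # 3363235840 bytes of headroom at concurrent_gb=100
```

With concurrent_gb=100, only about 3.36 GB of the reservation is left while this one array is in flight, so other reads are likely to have to wait for it to be released.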
So I set it to 100, and then on CPU I get:
RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop
Traceback (most recent call last):
  File "eval.py", line 279, in eval_main
    eval_loop(cfg)
  File "eval.py", line 166, in eval_loop
    (state, _, state_mesh_annotations, _,) = max_utils.setup_initial_state(
  File "/home/sam/character-tech/maxtext/MaxText/max_utils.py", line 424, in setup_initial_state
    ) = checkpointing.load_state_if_possible(
  File "/home/sam/character-tech/maxtext/MaxText/checkpointing.py", line 361, in load_state_if_possible
    params = ckpt_manager.restore(step=1, items=items, restore_kwargs=kw)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 565, in restore
    restored_items = self._restore_impl(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 597, in _restore_impl
    restored[item_name] = self._checkpointers[item_name].restore(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpointer.py", line 97, in restore
    restored = self._handler.restore(directory, *args, item=item, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 569, in restore
    restored_item = asyncio.run(_restore())
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
    return future.result()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 566, in _restore
    flat = await asyncio.gather(*flat)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 48, in maybe_get_async
    return await value.get_async(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 30, in get_async
    return await self._get_fn(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 448, in _deserialize
    return await handler.deserialize(info, args)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/type_handlers.py", line 541, in deserialize
    return await serialization.async_deserialize(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 287, in async_deserialize
    return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 55, in create_async_array_from_callback
    dbs = await asyncio.gather(*future_arrays)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 261, in cb
    await byte_limiter.wait_for_bytes(requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 130, in wait_for_bytes
    await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 400, in wait_for
    await self.wait()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 373, in wait
    await fut
RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop
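The bottom of the traceback goes wait_for_bytes, then asyncio.Condition.wait_for, then `await fut` inside Condition.wait(). On Python 3.8, an asyncio.Condition captures the event loop returned by get_event_loop() when it is constructed, and wait() creates its future on that captured loop. One plausible reading, not confirmed here, is that the byte limiter's Condition was built outside the loop that asyncio.run() starts in pytree_checkpoint_handler.py. In that case, a read that has to wait would create its future on the wrong loop. The sketch below uses a hypothetical LoopSafeByteLimiter (not JAX's actual limiter class), which creates its Condition from inside the running loop and recreates it when a later asyncio.run() starts a new one. The 60-byte reads against a 100-byte budget are made-up sizes chosen so that the second read must wait.

```python
import asyncio
from typing import List, Optional


class LoopSafeByteLimiter:
    """Hypothetical in-flight byte limiter (not the JAX implementation).

    The asyncio.Condition is created lazily inside the running event loop and
    recreated if a later asyncio.run() call is running a different loop.
    """

    def __init__(self, max_bytes: int) -> None:
        self._max_bytes = max_bytes
        self._available = max_bytes
        self._cv: Optional[asyncio.Condition] = None
        self._cv_loop: Optional[asyncio.AbstractEventLoop] = None

    def _condition(self) -> asyncio.Condition:
        loop = asyncio.get_running_loop()
        if self._cv is None or self._cv_loop is not loop:
            # Built while `loop` is running, so on Python 3.8 its waiter
            # futures are created on this loop, not on a stale default loop.
            self._cv = asyncio.Condition()
            self._cv_loop = loop
        return self._cv

    async def wait_for_bytes(self, requested: int) -> None:
        if requested > self._max_bytes:
            raise ValueError(
                "Requested more bytes than we reserved space for: "
                f"{requested} > {self._max_bytes}")
        cv = self._condition()
        async with cv:
            # Suspend until enough of the budget has been released.
            await cv.wait_for(lambda: self._available >= requested)
            self._available -= requested

    async def release_bytes(self, requested: int) -> None:
        cv = self._condition()
        async with cv:
            self._available += requested
            cv.notify_all()


async def read_chunk(limiter: LoopSafeByteLimiter, nbytes: int,
                     name: str, log: List[str]) -> None:
    await limiter.wait_for_bytes(nbytes)
    log.append(f"start {name}")
    await asyncio.sleep(0)  # stand-in for the actual read
    log.append(f"end {name}")
    await limiter.release_bytes(nbytes)


async def restore(limiter: LoopSafeByteLimiter) -> List[str]:
    log: List[str] = []
    # Two 60-byte reads against a 100-byte budget: "b" must wait for "a".
    await asyncio.gather(read_chunk(limiter, 60, "a", log),
                         read_chunk(limiter, 60, "b", log))
    return log


limiter = LoopSafeByteLimiter(100)
first = asyncio.run(restore(limiter))   # first event loop
second = asyncio.run(restore(limiter))  # a fresh loop reusing the same limiter
print(first)   # ['start a', 'end a', 'start b', 'end b']
print(second)  # ['start a', 'end a', 'start b', 'end b']
```

Recreating the Condition when the loop changes is only safe if nothing is still waiting on the old one. That holds here because each asyncio.run() finishes all its reads before returning.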
On TPU I run the same command, just without the env vars.
Any idea?