Elappnano closed this issue 8 months ago.
Hello, have you had a chance to try the latest version of JAX-PI? I believe I've resolved the issue by converting all NumPy arrays to JAX arrays, as shown below:
import jax
import jax.numpy as jnp
from flax.training import checkpoints


def restore_checkpoint(state, workdir, step=None):
    # Check whether the passed state is sharded across devices (pmap).
    # jnp.array converts any NumPy leaves so that `.sharding` is available.
    if isinstance(
        jax.tree_util.tree_map(
            lambda x: jnp.array(x).sharding, jax.tree_util.tree_leaves(state.params)
        )[0],
        jax.sharding.PmapSharding,
    ):
        # If so, reduce to a single-device sharding by keeping device 0's copy.
        state = jax.tree_util.tree_map(lambda x: x[0], state)
    # Ensure that we're now in a single-device setting.
    assert isinstance(
        jax.tree_util.tree_map(
            lambda x: jnp.array(x).sharding, jax.tree_util.tree_leaves(state.params)
        )[0],
        jax.sharding.SingleDeviceSharding,
    )
    state = checkpoints.restore_checkpoint(workdir, state, step=step)
    return state
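Why the conversion matters can be seen in a small sketch, assuming a JAX version (0.4 or later) in which arrays expose `.sharding`: a NumPy leaf, for example one restored into host memory, has no `.sharding` attribute, whereas the same data wrapped with `jnp.asarray` reports a `SingleDeviceSharding`. The helpers `leaf_sharding` and `to_single_device` and the toy `params` pytree are hypothetical illustrations, not part of JAX-PI or Flax.

```python
import numpy as np
import jax
import jax.numpy as jnp


def leaf_sharding(params):
    """Return the sharding of the first leaf of a params pytree (hypothetical helper)."""
    leaf = jax.tree_util.tree_leaves(params)[0]
    if not isinstance(leaf, jax.Array):
        # NumPy leaves have no `.sharding`; convert to a JAX array on the default device.
        leaf = jnp.asarray(leaf)
    return leaf.sharding


def to_single_device(params):
    """Keep device 0's copy if the params are pmap-replicated (hypothetical helper)."""
    if isinstance(leaf_sharding(params), jax.sharding.PmapSharding):
        params = jax.tree_util.tree_map(lambda x: x[0], params)
    return params


# Hypothetical params pytree holding a plain NumPy array, as after a host-side restore.
params = {"dense": {"kernel": np.ones((2, 3), dtype=np.float32)}}

print(hasattr(params["dense"]["kernel"], "sharding"))  # False
print(isinstance(leaf_sharding(params), jax.sharding.SingleDeviceSharding))  # True

# Not pmap-replicated, so the pytree is returned unchanged.
single = to_single_device(params)
print(single["dense"]["kernel"].shape)  # (2, 3)
```

Converting only leaves that are not already a `jax.Array` avoids copying arrays that are on device; the snippet above reaches the same check by calling `jnp.array` (which copies) on every leaf.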
The issue is resolved, thanks.
Hello, the following error appears when trying to restore checkpoints for the second time window: