PredictiveIntelligenceLab / jaxpi


Error in restoring checkpoint for second time window #4

Closed: Elappnano closed this issue 8 months ago

Elappnano commented 9 months ago

Hello, the following error appears when I try to restore checkpoints for the second time window:

```
Exception has occurred: AttributeError
'numpy.ndarray' object has no attribute 'sharding'
```
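A minimal sketch that reproduces the same message outside JAX-PI. It assumes, as the error implies, that the parameters restored after the first time window come back as plain NumPy arrays. The array `leaf` is a hypothetical stand-in for one restored parameter.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for one parameter leaf restored as a NumPy array.
leaf = np.zeros((4,), dtype=np.float32)

try:
    leaf.sharding  # NumPy arrays carry no device/sharding information
except AttributeError as err:
    message = str(err)
    print(message)  # 'numpy.ndarray' object has no attribute 'sharding'

# After conversion to a JAX array, .sharding is available.
converted = jnp.array(leaf)
print(converted.sharding)
```

Only `jax.Array` objects expose `.sharding`, so any code that inspects the sharding of restored leaves must convert NumPy leaves first.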

sifanexisted commented 8 months ago

Hello, have you had a chance to try the latest version of JAX-PI? I believe I've resolved the issue by converting the NumPy arrays into JAX arrays (with `jnp.array`) before their `.sharding` attribute is read, as shown below:

```python
import jax
import jax.numpy as jnp
from flax.training import checkpoints


def restore_checkpoint(state, workdir, step=None):
    # Check whether the passed state is replicated across devices (pmap).
    # jnp.array turns a NumPy leaf (e.g. one left by an earlier restore)
    # into a JAX array, so .sharding is always defined.
    first_leaf = jax.tree_util.tree_leaves(state.params)[0]
    if isinstance(jnp.array(first_leaf).sharding, jax.sharding.PmapSharding):
        # If so, keep only the first replica to get a single-device state.
        state = jax.tree_util.tree_map(lambda x: x[0], state)

    # Ensure that we're now in a single-device setting.
    first_leaf = jax.tree_util.tree_leaves(state.params)[0]
    assert isinstance(
        jnp.array(first_leaf).sharding, jax.sharding.SingleDeviceSharding
    )

    state = checkpoints.restore_checkpoint(workdir, state, step=step)
    return state
```

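A hedged sketch of the sharding check from `restore_checkpoint` above, run on a toy parameter tree instead of a real training state. The dict `params` is a hypothetical stand-in for `state.params` after a restore that returned NumPy leaves; it only illustrates why the single-device assertion now passes.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for state.params after a restore that returned NumPy leaves.
params = {
    "dense": {
        "kernel": np.ones((2, 2), dtype=np.float32),
        "bias": np.zeros((2,), dtype=np.float32),
    }
}

# Same test as in restore_checkpoint: convert the first leaf to a JAX array
# so that reading .sharding cannot raise AttributeError.
first_leaf = jax.tree_util.tree_leaves(params)[0]
sharding = jnp.array(first_leaf).sharding

is_replicated = isinstance(sharding, jax.sharding.PmapSharding)
is_single = isinstance(sharding, jax.sharding.SingleDeviceSharding)
print(is_replicated, is_single)  # False True
```

A freshly converted NumPy leaf lands on the default device, so it is reported as single-device and the replica-stripping branch is skipped.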
Elappnano commented 8 months ago

The issue is resolved, thanks.