Elappnano closed this issue 9 months ago
Do you get this error when restoring the checkpoint?
If so, please add
state = jax.device_get(tree_map(lambda x: x[0], model.state))
to extract the first replica of the state.
No, I got this error when running this line:
u_pred = u_pred_fn(params, t_coords, coords[:, 0], coords[:, 1])
The error is fixed. It occurred because these lines in the utils.py file of the jaxpi folder were commented out:
if isinstance(
    jax.tree_map(lambda x: x.sharding, jax.tree_leaves(state.params))[0],
    jax.sharding.PmapSharding,
):
    state = jax.tree_map(lambda x: x[0], state)
# ensuring that we're in a single device setting
assert isinstance(
    jax.tree_map(lambda x: x.sharding, jax.tree_leaves(state.params))[0],
    jax.sharding.SingleDeviceSharding,
)
I commented this out because the saved parameters are sometimes automatically converted to NumPy arrays, which leads to errors. Either way works.
So how should I fix the error without uncommenting these lines?
If you encounter a similar error, consider updating your code like this. The commented-out code does essentially the same thing but includes a type check (which may cause an error when the state parameters are NumPy arrays):
ckpt_path = os.path.join('.', 'ckpt', config.wandb.name, 'time_window_{}'.format(idx + 1))
model.state = restore_checkpoint(model.state, ckpt_path)
model.state = tree_map(lambda x: x[0], model.state)
params = model.state.params
The tree_map line extracts the first replica of the state, which is duplicated across all devices when it is created (see the code below). This explains the additional dimension you observed in the state parameters.
def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch)
    x = jnp.ones(config.input_dim)
    params = arch.init(random.PRNGKey(config.seed), x)
    # Initialize optax optimizer
    tx = _create_optimizer(config.optim)
    # Convert config dict to dict
    init_weights = dict(config.weighting.init_weights)
    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )
    return jax_utils.replicate(state)
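To make the extra dimension concrete, here is a minimal sketch. It does not call jax_utils.replicate itself; as a stand-in, it stacks copies of each leaf along a new leading axis, which is the shape effect described above. The parameter names, shapes, and the device count of 2 are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Toy parameter tree standing in for state.params (hypothetical shapes).
params = {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}
n_devices = 2  # assumed device count, for illustration only

# Stand-in for replication: copy every leaf along a new leading axis.
replicated = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_devices), params
)
print(replicated["kernel"].shape)  # leading axis has length n_devices

# Taking index 0 of every leaf recovers the original, unreplicated shapes.
first_replica = jax.tree_util.tree_map(lambda x: x[0], replicated)
print(first_replica["kernel"].shape)
```

Indexing with x[0] only makes sense while that leading axis is present, so it should be applied exactly once per replicated state.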
Well, the code runs fine for the first time window, but for the second one I got this error:
IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.
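For context, JAX raises an IndexError of this kind when a 0-dimensional array is indexed. One plausible cause, which this thread does not confirm, is that x[0] gets applied a second time to a state that has already been unreplicated, so scalar leaves no longer have a leading axis. The sketch below reproduces that situation with a hypothetical two-leaf state; the leaf names are made up for illustration, and the exact message text may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical, already-unreplicated state: "step" is a 0-d scalar leaf.
state = {"step": jnp.asarray(0), "w": jnp.ones((3, 4))}

try:
    # Indexing a 0-d leaf with [0] fails because there is no axis to index.
    jax.tree_util.tree_map(lambda x: x[0], state)
    raised = False
except IndexError as err:
    raised = True
    print(type(err).__name__, err)
```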
Please go ahead and uncomment the code in restore_checkpoint. I apologize for the inconvenience. I will restore the commented-out code in the source and add a type check for NumPy arrays.
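A rough sketch of what such a check could look like is below. It is not the jaxpi implementation: the helper name take_first_replica is hypothetical, and the logic assumes that only leaves carrying a jax.sharding.PmapSharding need their leading axis removed. NumPy arrays have no sharding attribute, so they are returned unchanged.

```python
import jax
import jax.numpy as jnp
import numpy as np


def take_first_replica(tree):
    """Drop the leading device axis only when the leaves are pmap-sharded.

    Hypothetical helper for illustration; not part of jaxpi.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # NumPy arrays have no `.sharding` attribute, so fall back to None.
    sharding = getattr(leaves[0], "sharding", None) if leaves else None
    # Looked up defensively in case the installed JAX version lacks the class;
    # isinstance(x, ()) is always False.
    pmap_cls = getattr(jax.sharding, "PmapSharding", ())
    if isinstance(sharding, pmap_cls):
        return jax.tree_util.tree_map(lambda x: x[0], tree)
    return tree


# Neither tree is pmap-sharded, so both come back with unchanged shapes.
out_np = take_first_replica({"w": np.ones((3, 4))})
out_jax = take_first_replica({"w": jnp.ones((3, 4))})
print(out_np["w"].shape, out_jax["w"].shape)
```

One limitation of this sketch: a NumPy array that still carries a replica axis cannot be told apart from an unreplicated one, so it is left as is.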
OK, thanks for your support.
When I ran the code, the following error appeared: