PredictiveIntelligenceLab / jaxpi


error in eval code for ns_unsteady_cylinder example #3

Closed Elappnano closed 9 months ago

Elappnano commented 9 months ago

When I run the code, the following error appears:

flax.errors.ScopeParamShapeError: Initializer expected to generate shape (1, 3, 128) but got shape (3, 128) instead for parameter "kernel" in "/FourierEmbs_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

sifanexisted commented 9 months ago

Do you get this error when restoring the checkpoint?

If so, please add state = jax.device_get(tree_map(lambda x: x[0], model.state)) to extract the first replica of the state.

Elappnano commented 9 months ago

No, I get this error when running this line:

u_pred = u_pred_fn(params, t_coords, coords[:, 0], coords[:, 1])

Elappnano commented 9 months ago

The error is fixed. It occurred because these lines in utils.py in the jaxpi folder were commented out:

    if isinstance(
        jax.tree_map(lambda x: x.sharding, jax.tree_leaves(state.params))[0],
        jax.sharding.PmapSharding,
    ):
        state = jax.tree_map(lambda x: x[0], state)

    # ensuring that we're in a single device setting
    assert isinstance(
        jax.tree_map(lambda x: x.sharding, jax.tree_leaves(state.params))[0],
        jax.sharding.SingleDeviceSharding,
    )

sifanexisted commented 9 months ago

I commented this out because the saved parameters are sometimes automatically converted to NumPy arrays, which makes the sharding check fail. Either way works.

Elappnano commented 9 months ago

So how should I fix the error without uncommenting these lines?

sifanexisted commented 9 months ago

If you encounter a similar error, consider updating your code like this. The commented-out code essentially does the same thing but includes a type check (which may cause an error when the state parameters are NumPy arrays):

        ckpt_path = os.path.join('.', 'ckpt', config.wandb.name, 'time_window_{}'.format(idx + 1))
        model.state = restore_checkpoint(model.state, ckpt_path)
        model.state = tree_map(lambda x: x[0], model.state)
        params = model.state.params

The tree_map line extracts the first replica of the state, which is duplicated across all devices when it is created (see the code below). This explains the additional leading dimension you observed in the state parameters.

def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch)
    x = jnp.ones(config.input_dim)
    params = arch.init(random.PRNGKey(config.seed), x)

    # Initialize optax optimizer
    tx = _create_optimizer(config.optim)

    # Convert config dict to dict
    init_weights = dict(config.weighting.init_weights)

    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )

    return jax_utils.replicate(state)
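
A minimal sketch of the shape effect described above, assuming replication can be emulated by stacking one copy of each leaf per local device (a stand-in for jax_utils.replicate, not the call itself). The "kernel" entry is a hypothetical placeholder with the (3, 128) shape from the error message.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the FourierEmbs kernel from the error message.
params = {"kernel": jnp.ones((3, 128))}

# Assumption: emulate replication by stacking one copy per local device,
# which adds a leading device axis to every leaf.
n_devices = jax.local_device_count()
replicated = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_devices), params
)

# Indexing the leading axis at 0 recovers the first replica.
first_replica = jax.tree_util.tree_map(lambda x: x[0], replicated)

replicated_shape = replicated["kernel"].shape    # (n_devices, 3, 128)
restored_shape = first_replica["kernel"].shape   # (3, 128)
print(replicated_shape, restored_shape)
```

On a single device the replicated kernel has shape (1, 3, 128), which is the shape the initializer check in the original error was compared against.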

Elappnano commented 9 months ago

Well, the code runs fine for the first time window, but for the second I get this error:

IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.

sifanexisted commented 9 months ago

Please go ahead and uncomment the code in restore_checkpoint. I apologize for the inconvenience. I will restore the commented-out code in the source and make sure to add a type check for NumPy arrays.
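
In the meantime, here is a rough sketch of one way to make the unreplication step safe to apply more than once. It is not the planned fix in restore_checkpoint: instead of checking the sharding type, it compares each restored leaf against an unreplicated template and drops the leading axis only when the leaf has exactly one extra dimension. The helper name strip_replica_axis and the example trees are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def strip_replica_axis(restored, template):
    """Drop a leading replica axis from leaves that have one more
    dimension than the matching leaf in the unreplicated template."""

    def fix(leaf, ref):
        leaf = np.asarray(leaf)  # accepts both JAX and NumPy arrays
        if leaf.ndim == np.ndim(ref) + 1:
            return leaf[0]       # still replicated: take the first copy
        return leaf              # already unreplicated: leave untouched

    return jax.tree_util.tree_map(fix, restored, template)


# Hypothetical unreplicated template and a restored, replicated state.
template = {"kernel": jnp.ones((3, 128)), "step": jnp.array(0)}
restored = {"kernel": np.ones((1, 3, 128)), "step": np.array([0])}

once = strip_replica_axis(restored, template)
twice = strip_replica_axis(once, template)  # second call changes nothing

print(once["kernel"].shape, once["step"].shape)
print(twice["kernel"].shape, twice["step"].shape)
```

A plain tree_map(lambda x: x[0], state) applied a second time would index into leaves that are already unreplicated, and a 0-dimensional leaf such as a step counter cannot be indexed at all, which may be what triggers the IndexError in the second time window.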

Elappnano commented 9 months ago

OK, thanks for your support.