PredictiveIntelligenceLab / jaxpi


How to define trainable parameter for inverse problems #20

Open jwson302 opened 3 months ago

jwson302 commented 3 months ago

I am trying to solve an inverse problem with the jaxpi package. How would I go about defining the trainable inverse parameter? Would I need to change the source code of the library?

sifanexisted commented 3 months ago

Thank you for your interest in our work! To solve inverse problems, you'll need to make some modifications to the source code. Specifically, you can add your inverse parameters as additional key-value pairs in the `params` dictionary that is passed to `TrainState.create`, so that the optimizer trains them together with the network weights. Here’s a possible approach, a modified `_create_train_state`:

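```python
import jax.numpy as jnp
from jax import random
from flax import jax_utils

# _create_arch, _create_optimizer and TrainState are the existing jaxpi
# helpers this function already uses; only the params construction changes.


def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch)
    x = jnp.ones(config.input_dim)
    nn_params = arch.init(random.PRNGKey(config.seed), x)

    # Initial guess for the unknown PDE coefficient, stored as a JAX array so
    # it is a trainable leaf of the parameter pytree next to the network weights
    pde_params = jnp.array(0.0)
    params = {'nn_params': nn_params, 'pde_params': pde_params}

    # Initialize optax optimizer (it updates every leaf of params)
    tx = _create_optimizer(config.optim)

    # Convert the ConfigDict of initial loss weights to a plain dict
    init_weights = dict(config.weighting.init_weights)

    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )

    # Copy the state to every local device for pmap-based training
    return jax_utils.replicate(state)
```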
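The optimizer itself should not need changes because JAX and optax treat the nested `params` dict as a single pytree: gradients come back with the same `{'nn_params', 'pde_params'}` structure, and updates are applied leaf by leaf. A minimal sketch of that mechanism, assuming a hypothetical toy parameter tree and loss in place of the jaxpi network, and a plain gradient-descent step in place of the configured optax optimizer:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree mirroring the layout above: network weights
# under "nn_params" and the unknown PDE coefficient under "pde_params".
params = {
    "nn_params": {"w": jnp.ones((3,)), "b": jnp.array(0.0)},
    "pde_params": jnp.array(0.0),  # initial guess for the inverse parameter
}


def toy_loss(params):
    # Stand-in objective that depends on both groups of parameters
    w, b = params["nn_params"]["w"], params["nn_params"]["b"]
    lam = params["pde_params"]
    return jnp.sum(w**2) + b**2 + (lam - 2.0) ** 2


# Gradients have exactly the same nested structure as params
grads = jax.grad(toy_loss)(params)

# One plain gradient-descent step applied leaf by leaf; an optax optimizer
# traverses the same tree, so pde_params is updated along with the weights
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Here `grads["pde_params"]` is `-4.0`, so one step moves the coefficient from `0.0` to `0.4`, while the network leaves are updated in the same call.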

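You’ll also need to unpack these entries separately inside the model: pass `params['nn_params']` to the network's `apply_fn` (it was initialized by `arch.init`, so `arch.apply` expects exactly that subtree), and read `params['pde_params']` wherever the PDE residual and loss use the unknown coefficient.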
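To make the unpacking concrete, here is a minimal, self-contained JAX sketch on a hypothetical toy inverse problem: recovering the decay rate `lam` in u'(t) + lam u(t) = 0 from observations of u. It does not use jaxpi's model classes; `init_mlp`, `u_net`, `r_net` and `loss_fn` are illustrative stand-ins. The network only ever receives `params["nn_params"]`, the residual reads `params["pde_params"]`, and a single `jax.grad` call returns gradients for both.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy inverse problem: u'(t) + lam * u(t) = 0 on [0, 1], with
# observations generated from lam_true = 2, i.e. u_obs(t) = exp(-2 t).


def init_mlp(key, width=16):
    # Tiny one-hidden-layer MLP standing in for the jaxpi network
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (width,)),
        "b1": jnp.zeros(width),
        "W2": jax.random.normal(k2, (width,)) / jnp.sqrt(width),
        "b2": jnp.array(0.0),
    }


def u_net(nn_params, t):
    # Scalar input t -> scalar output u(t); sees only the network weights
    h = jnp.tanh(t * nn_params["W1"] + nn_params["b1"])
    return h @ nn_params["W2"] + nn_params["b2"]


def r_net(params, t):
    # Unpack the two groups separately: weights for the network,
    # the unknown coefficient for the PDE residual
    nn_params = params["nn_params"]
    lam = params["pde_params"]
    u = u_net(nn_params, t)
    u_t = jax.grad(u_net, argnums=1)(nn_params, t)
    return u_t + lam * u


def loss_fn(params, t_res, t_obs, u_obs):
    r = jax.vmap(r_net, in_axes=(None, 0))(params, t_res)
    u_pred = jax.vmap(u_net, in_axes=(None, 0))(params["nn_params"], t_obs)
    # Physics residual loss plus data-misfit loss
    return jnp.mean(r**2) + jnp.mean((u_pred - u_obs) ** 2)


params = {
    "nn_params": init_mlp(jax.random.PRNGKey(0)),
    "pde_params": jnp.array(0.0),
}
t_res = jnp.linspace(0.0, 1.0, 64)
t_obs = jnp.linspace(0.0, 1.0, 16)
u_obs = jnp.exp(-2.0 * t_obs)

# One gradient call covers both groups; grads["pde_params"] is a scalar
grads = jax.grad(loss_fn)(params, t_res, t_obs, u_obs)
```

The data term matters here: with the residual alone, u = 0 satisfies the ODE for any `lam`, so the coefficient would not be identifiable. The resulting `grads` can be passed to the optimizer as usual.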

Hope this helps!