jwson302 opened 3 months ago
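Thank you for your interest in our work! To solve inverse problems, you'll need to make some modifications to the source code. Specifically, you can add your inverse parameters as additional key-value pairs in the params dictionary that holds the trainable state. Because the optimizer updates every entry of that dictionary, the inverse parameters are then trained together with the network weights. Here's one possible approach, a modified _create_train_state: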
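import jax.numpy as jnp
from jax import random
from flax import jax_utils

# TrainState, _create_arch and _create_optimizer are assumed to come from
# the surrounding jaxpi module that defines this function.


def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch)
    x = jnp.ones(config.input_dim)
    nn_params = arch.init(random.PRNGKey(config.seed), x)

    # Trainable inverse (PDE) parameter, e.g. an unknown coefficient.
    # A JAX array gives it a concrete dtype and shape like the network
    # weights; 0.0 is only the initial guess.
    pde_params = jnp.array(0.0)

    # Both sub-trees live in one dict, so the optimizer updates them together
    params = {'nn_params': nn_params, 'pde_params': pde_params}

    # Initialize optax optimizer
    tx = _create_optimizer(config.optim)

    # Convert the ConfigDict of initial loss weights to a plain dict
    init_weights = dict(config.weighting.init_weights)

    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )
    return jax_utils.replicate(state)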
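You'll also need to extract these parameters separately inside the model's forward pass and loss functions: pass only params['nn_params'] to the network (for example to state.apply_fn), and read params['pde_params'] wherever the PDE residual uses the unknown coefficient.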
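A minimal sketch of that separation, on a hypothetical toy inverse problem (u'(x) = k u(x) on [0, 1] with k unknown, data generated from exp(2x)). To stay dependency-light it uses plain JAX only: init_mlp and mlp_apply are hypothetical stand-ins for arch.init and arch.apply, and a plain gradient-descent step stands in for the optax optimizer built by _create_optimizer. The names u_net, residual, loss_fn and gd_step are illustrative and are not meant to reproduce jaxpi's actual model methods.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical toy inverse problem, for illustration only:
#   ODE  u'(x) = k * u(x) on [0, 1] with k unknown,
#   data u_obs(x) = exp(2 x), so the true value is k = 2.


def init_mlp(key, sizes=(1, 16, 16, 1)):
    """Small tanh MLP as a list of (W, b) pairs (hypothetical stand-in for arch.init)."""
    keys = random.split(key, len(sizes) - 1)
    return [
        (random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for sub, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp_apply(nn_params, x):
    """Map a scalar x to a scalar u(x) (hypothetical stand-in for arch.apply)."""
    h = jnp.reshape(x, (1,))
    for W, b in nn_params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = nn_params[-1]
    return (h @ W + b)[0]


def u_net(params, x):
    # The network only receives its own sub-tree of the params dict.
    return mlp_apply(params["nn_params"], x)


def residual(params, x):
    # The unknown coefficient is read from its own key.
    k = params["pde_params"]
    u = u_net(params, x)
    u_x = jax.grad(u_net, argnums=1)(params, x)
    return u_x - k * u


def loss_fn(params, x_data, u_data, x_res):
    # Data misfit + PDE residual; gradients flow to both sub-trees.
    u_pred = jax.vmap(u_net, in_axes=(None, 0))(params, x_data)
    r = jax.vmap(residual, in_axes=(None, 0))(params, x_res)
    return jnp.mean((u_pred - u_data) ** 2) + jnp.mean(r ** 2)


@jax.jit
def gd_step(params, x_data, u_data, x_res):
    # Plain gradient descent on the whole dict (stand-in for the optax
    # optimizer): pde_params is updated like any other leaf.
    loss, grads = jax.value_and_grad(loss_fn)(params, x_data, u_data, x_res)
    params = jax.tree_util.tree_map(lambda p, g: p - 5e-3 * g, params, grads)
    return params, loss


if __name__ == "__main__":
    x_data = jnp.linspace(0.0, 1.0, 8)
    u_data = jnp.exp(2.0 * x_data)
    x_res = jnp.linspace(0.0, 1.0, 32)
    params = {
        "nn_params": init_mlp(random.PRNGKey(0)),
        "pde_params": jnp.array(0.0),  # initial guess for k
    }
    for _ in range(2000):
        params, loss = gd_step(params, x_data, u_data, x_res)
    print(f"loss={float(loss):.4f}  k estimate={float(params['pde_params']):.4f}")
```

The key point is that jax.grad returns gradients with the same {'nn_params', 'pde_params'} structure as params, so no special handling is needed in the optimizer step. As the data fit improves, the residual term is expected to pull the estimate of k toward 2. Plain gradient descent is slow, so in practice you would keep the optax optimizer from the train state.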
Hope this helps!
I am trying to solve an inverse problem with the jaxpi package. How would I go about defining the trainable inverse parameter? Would I need to change the source code of the library?