Closed SNMS95 closed 1 year ago
jaxopt.OptaxSolver
def train_nn_for_five_imgs():
    """Fit a freshly created network to a target image, five times over.

    One ``jaxopt.OptaxSolver`` is built around an optax chain (global-norm
    gradient clipping followed by Adam) and reused for five independent
    training runs of ``max_iterations`` update steps each.

    The names ``grad_clip``, ``lr``, ``max_iterations``, ``init_params``,
    ``init_nn_state`` and ``target_image`` are not defined in this snippet;
    they are assumed to come from the enclosing scope or from the elided
    "Create a NN" step -- TODO confirm against the full script.
    """

    def loss_fn(params, nn_state, target):
        """Return ``(loss, new_nn_state)``.

        The solver is constructed with ``has_aux=True``, so the second
        element of the returned pair is exposed afterwards as ``state.aux``.
        """
        # calculate MSE
        # NOTE(review): the forward pass and MSE computation were elided in
        # the original snippet; ``MSE`` and ``nn_state_new`` must be produced
        # here before this function can actually run.
        return MSE, nn_state_new

    opt = optax.chain(
        optax.clip_by_global_norm(grad_clip),
        optax.adam(learning_rate=lr),
    )
    optimizer = jaxopt.OptaxSolver(
        fun=loss_fn,
        opt=opt,
        value_and_grad=False,
        has_aux=True,
        jit=False,
        verbose=False,
    )

    for _run in range(5):
        # Create a NN
        # NOTE(review): network creation was elided in the original snippet;
        # it presumably yields init_params / init_nn_state for this run.

        # Start every run from the fresh initial values. The original passed
        # `params` and `nn_state` to `update` without ever assigning them,
        # which fails on the first run and would otherwise leak the previous
        # run's trained values into the next one.
        params = init_params
        nn_state = init_nn_state
        opt_state = optimizer.init_state(init_params, init_nn_state, target_image)

        for _ in range(max_iterations):
            params, opt_state = optimizer.update(
                params, opt_state, nn_state, target_image
            )
            # Carry the auxiliary network state returned by loss_fn into the
            # next step.
            nn_state = opt_state.aux
See #380 for a potential fix.
jaxopt.OptaxSolver
optimizer.