google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Garbage collection issues #548

Closed: SNMS95 closed this issue 7 months ago.

SNMS95 commented 8 months ago
import jaxopt
import optax


def train_nn_for_five_imgs():
    def loss_fn(params, nn_state, target):
        # calculate MSE; because has_aux=True, the function must return
        # (loss, aux), where aux is the updated network state
        return MSE, nn_state_new

    opt = optax.chain(optax.clip_by_global_norm(grad_clip),
                      optax.adam(learning_rate=lr))
    optimizer = jaxopt.OptaxSolver(
        fun=loss_fn,
        opt=opt,
        value_and_grad=False,
        has_aux=True,
        jit=False,
        verbose=False)
    for i in range(5):
        # Create a NN: start each run from fresh parameters and state
        params, nn_state = init_params, init_nn_state
        opt_state = optimizer.init_state(params, nn_state, target_image)
        for _ in range(max_iterations):
            params, opt_state = optimizer.update(params, opt_state,
                                                 nn_state, target_image)
            # the aux output of loss_fn holds the updated network state
            nn_state = opt_state.aux

mblondel commented 8 months ago

See #380 for a potential fix.