google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Passing a custom callback to jaxopt.ScipyMinimize #257

Open richinex opened 2 years ago

richinex commented 2 years ago

I would like to ask whether there is any way to pass a custom callback function to one of the solvers (TNC, in this case). For example, I can make the code below work with scipy.optimize.minimize (BTW, I used jnp in both cases):

# This works
res = scipy.optimize.minimize(
    self.imp_sim_weighted.simulate, ln_par0,
    args=(self.controller.app_data["freq"], self.controller.app_data["z"], lb_col, ub_col, self.smf, self.controller.app_data["weight"]),
    method='TNC', jac=self.controller.jac_deis, callback=self.imp_sim.callback,
    options={'maxfun': 10000, 'ftol': 1e-10, 'xtol': 1e-10})

But with jaxopt, I get an error:

solver = jaxopt.ScipyMinimize(method="TNC", fun=self.imp_sim.simulate, tol=1e-12,
                              options={'maxiter': 5000, 'callback': self.imp_sim.callback})
sol = solver.run(ln_par0, self.controller.app_data["freq"], self.controller.app_data["y"],
                 lb_col, ub_col, self.smf, self.controller.app_data["weight"])

Error:

#     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
# jax._src.traceback_util.UnfilteredStackTrace: TypeError: scipy.optimize._tnc._minimize_tnc() got multiple values for keyword argument 'callback'
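Judging from the traceback, the `options` dict given to `jaxopt.ScipyMinimize` ends up in `scipy.optimize.minimize`, which calls the method-specific routine with an explicit `callback=callback` and also spreads the contents of `options` into that same call as keyword arguments. A `'callback'` key in `options` therefore collides with the explicit keyword. The sketch below reproduces the same `TypeError` with SciPy alone; the quadratic `objective` and the no-op `record` are hypothetical stand-ins for the issue's `simulate` and `callback`.

```python
import numpy as np
import scipy.optimize


def objective(x):
    # Hypothetical stand-in for simulate(): a quadratic with its minimum at x = 1.
    return float(np.sum((x - 1.0) ** 2))


def record(xk):
    # Hypothetical no-op callback; SciPy would call it with the current iterate.
    pass


# A 'callback' key inside options collides with minimize's explicit callback= keyword.
try:
    scipy.optimize.minimize(objective, np.zeros(2), method="TNC",
                            options={"callback": record})
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# Passing the callback as minimize's own keyword argument runs normally.
res = scipy.optimize.minimize(objective, np.zeros(2), method="TNC", callback=record)
```

This is why the plain SciPy call earlier in the issue works: there the callback goes through `minimize`'s `callback` parameter rather than through `options`.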
mblondel commented 2 years ago

Currently it's not supported, but this would indeed be nice to have.
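While the wrapper does not accept a callback, one workaround is to call SciPy directly with a JAX-compiled objective and pass the callback through `minimize`'s own `callback` argument, as in the working snippet above. A minimal sketch, assuming you do not need to differentiate through the solution: `objective`, `target`, `fun_and_jac`, and `record` are hypothetical stand-ins for the issue's `simulate` and `callback`, not jaxopt APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize

# SciPy works in float64; enable it in JAX so values and gradients match.
jax.config.update("jax_enable_x64", True)


def objective(params, target):
    # Hypothetical stand-in for simulate(): squared distance to a target vector.
    return jnp.sum((params - target) ** 2)


# Compile value and gradient together; only `params` (argument 0) is differentiated.
value_and_grad = jax.jit(jax.value_and_grad(objective))


def fun_and_jac(x, target):
    value, grad = value_and_grad(x, target)
    # Return plain Python/NumPy types so SciPy never sees JAX arrays.
    return float(value), np.asarray(grad, dtype=np.float64)


history = []


def record(xk):
    # Hypothetical callback: SciPy calls it after each TNC iteration.
    history.append(np.copy(xk))


target = np.array([1.0, -2.0, 3.0])
res = scipy.optimize.minimize(
    fun_and_jac, np.zeros(3), args=(target,), method="TNC", jac=True,
    callback=record, options={"maxfun": 1000, "ftol": 1e-10, "xtol": 1e-10},
)
print(res.x, len(history))
```

With `jac=True`, SciPy expects the objective to return `(value, gradient)`, so one jitted `value_and_grad` call serves both. Enabling `jax_enable_x64` is a global switch; without it JAX would compute in float32 while tolerances such as `1e-10` assume float64.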