google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
922 stars, 64 forks

Pytree bounds for `jaxopt.ScipyBoundedMinimize` #601

Closed: lgrcia closed this issue 3 months ago

lgrcia commented 3 months ago

When using `jaxopt.ScipyBoundedMinimize` with the initial parameters given as a dict, how do I specify the bounds using a matching dict-like structure?

The following fails:

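```python
import jaxopt

# `fun` is the objective, taking the dict of parameters; its definition is not shown in the issue.
init = {"a": 0.1, "b": 0.2}

solver = jaxopt.ScipyBoundedMinimize(fun=fun)
result = solver.run(
    init,
    bounds=(
        {"a": 0.0, "b": 0.0},  # lower bounds, same keys as init
        {"a": 1.0, "b": 1.0},  # upper bounds, same keys as init
    ),
)
```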

The documentation describes `bounds` as "an optional tuple (lb, ub) of pytrees with structure identical to init_params, representing box constraints", so this is probably my misunderstanding of pytree structure rather than a bug. Thanks for your help!
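The requirement quoted above can be checked directly with `jax.tree_util`: two pytrees have "identical structure" when their treedefs (container types and dict keys, ignoring leaf values) compare equal. A minimal sketch, assuming only that `jax` is installed; the names `lower`, `upper`, and `missing_key` are illustrative:

```python
import jax

init = {"a": 0.1, "b": 0.2}
lower = {"a": 0.0, "b": 0.0}
upper = {"a": 1.0, "b": 1.0}

# The treedef records container types and dict keys, not the leaf values.
init_def = jax.tree_util.tree_structure(init)
bounds_match = all(
    jax.tree_util.tree_structure(b) == init_def for b in (lower, upper)
)
print(bounds_match)  # True

# Building the bounds from init with tree_map guarantees an identical structure.
lower_from_init = jax.tree_util.tree_map(lambda _: 0.0, init)
upper_from_init = jax.tree_util.tree_map(lambda _: 1.0, init)
print(jax.tree_util.tree_structure(lower_from_init) == init_def)  # True

# A bound dict with a missing key has a different structure.
missing_key = {"a": 0.0}
print(jax.tree_util.tree_structure(missing_key) == init_def)  # False
```

Deriving the bounds from `init` with `tree_map` is one way to keep them in sync if parameters are later added or renamed.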

lgrcia commented 3 months ago

I think this works as expected; the issue was in the model parameters!