google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Fix complex support for L-BFGS #1142

Open gautierronan opened 1 day ago

gautierronan commented 1 day ago

Closes https://github.com/google-deepmind/optax/issues/1141.

I'm not 100% sure that this doesn't break anything else or that it fully works, but at least the MWE below seems to work fine.

import optax
import jax.numpy as jnp

# Real-valued objective of complex parameters.
def f(x):
    return jnp.sum(jnp.abs(x**2))

solver = optax.lbfgs()
params = jnp.array([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j])
print("Objective function: ", f(params))

opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(5):
    value, grad = value_and_grad(params, state=opt_state)
    # Pass the conjugated gradient so that the resulting update is a
    # descent direction for a real-valued objective of complex parameters.
    updates, opt_state = solver.update(
        jnp.conj(grad), opt_state, params, value=value, grad=jnp.conj(grad), value_fn=f
    )
    params = optax.apply_updates(params, updates)
    print("Objective function: ", f(params))

Notice that the solver.update call requires jnp.conj(grad) twice. I believe this is correct and aligned with other optax solvers, but I'm not sure either.
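
For context, here's a small standalone check of the gradient convention involved (this reflects my understanding of JAX's behavior for real-valued functions of complex inputs, and is not part of this PR's diff): jax.grad returns the conjugate of the steepest-ascent direction, so a descent step uses the conjugated gradient, which is what passing jnp.conj(grad) achieves above.

import jax
import jax.numpy as jnp

# f(z) = |z|^2 = x^2 + y^2, minimized at z = 0.
f = lambda z: jnp.abs(z) ** 2

z = 3.0 + 4.0j
g = jax.grad(f)(z)              # complex-valued gradient
z_next = z - 0.1 * jnp.conj(g)  # plain gradient-descent step on complex params
print(f(z), "->", f(z_next))    # the objective decreases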

vroulet commented 1 day ago

Hey @gautierronan, thanks for the PR! We'll need a test. Take a look at this PR: https://github.com/google/jaxopt/pull/468, which added support for complex parameters to the L-BFGS in jaxopt. I think you'll find everything you need in that PR. Thanks again!
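
For illustration, a minimal sketch of the kind of test this could use (the problem, tolerances, and structure below are only a suggestion, not the test ultimately added; the jaxopt PR above has more thorough parametrized coverage): minimize a complex least-squares objective with optax.lbfgs and check convergence to the known minimizer.

import jax.numpy as jnp
import numpy as np
import optax

def test_lbfgs_complex_params():
    # Real-valued quadratic of complex parameters; the unique minimizer
    # is the complex target vector.
    target = jnp.array([1.0 - 2.0j, -0.5 + 0.3j])

    def f(x):
        return jnp.sum(jnp.abs(x - target) ** 2)

    solver = optax.lbfgs()
    params = jnp.zeros_like(target)
    opt_state = solver.init(params)
    value_and_grad = optax.value_and_grad_from_state(f)

    for _ in range(20):
        value, grad = value_and_grad(params, state=opt_state)
        updates, opt_state = solver.update(
            jnp.conj(grad), opt_state, params,
            value=value, grad=jnp.conj(grad), value_fn=f,
        )
        params = optax.apply_updates(params, updates)

    np.testing.assert_allclose(params, target, atol=1e-4)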

gautierronan commented 10 hours ago

@vroulet Should be good for review.