patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

`BestSoFar...` intended behavior? #41

Closed: thibmonsel closed this issue 7 months ago

thibmonsel commented 7 months ago

Hi Patrick,

I suppose that the `BestSoFar...` solvers are working as intended.

```python
import jax.numpy as jnp
import optimistix as optx

# f(x) = x / 2 has its fixed point at x = 0.
def f(x, _):
    return 1 / 2 * x

solv = optx.FixedPointIteration(rtol=10e-3, atol=10e-3)
sol = optx.fixed_point(f, solv, jnp.array(0.3), max_steps=10)  # converges
print(sol.value)

solv2 = optx.BestSoFarFixedPoint(solv)
# Too few steps to converge: this raises an error.
sol2 = optx.fixed_point(f, solv2, jnp.array(0.3), max_steps=3)
print(sol2.value)
```

`solv2` raises an `XlaRuntimeError`. Could it instead return the best value computed so far (here, the one at `max_steps=2`), even if the latest iterate of `sol2` doesn't satisfy the termination condition?

If not, is there a way to get the best value computed so far without raising a runtime error?
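To make the requested behavior concrete, here is a plain-Python sketch of a fixed-point loop that remembers the iterate with the smallest residual |f(y) − y| and, when `max_steps` runs out, returns that iterate together with a "not converged" flag instead of raising. The helper `best_so_far_fixed_point` is hypothetical and is not optimistix's implementation; its termination test |Δy| < atol + rtol·|y_new| is an assumption chosen to resemble the `rtol`/`atol` tolerances of `FixedPointIteration`.

```python
def best_so_far_fixed_point(fn, y0, max_steps, atol=1e-2, rtol=1e-2):
    """Hypothetical helper: iterate y <- fn(y), tracking the best iterate.

    Returns (value, converged, num_steps). On convergence, value is the
    latest iterate; otherwise it is the iterate with the smallest residual.
    """
    y = y0
    best_y, best_residual = y0, float("inf")
    for step in range(max_steps):
        new_y = fn(y)
        residual = abs(new_y - y)  # residual of the current iterate y
        if residual < best_residual:
            best_y, best_residual = y, residual
        if residual < atol + rtol * abs(new_y):
            return new_y, True, step + 1
        y = new_y
    # Step budget exhausted: report the best iterate instead of raising.
    return best_y, False, max_steps


value, converged, steps = best_so_far_fixed_point(lambda x: 0.5 * x, 0.3, max_steps=3)
print(value, converged, steps)
```

With `f(x) = x / 2`, `y0 = 0.3` and `max_steps=3`, this returns `0.075` with `converged=False`: the iterate produced at step 2, because it has the smallest residual that was actually evaluated. The caller, not an exception, decides what to do with an unconverged result.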

patrick-kidger commented 7 months ago

Yup: pass `fixed_point(..., throw=False)`.