patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Accessing the (success/not) result of the solver #72

Open djbower opened 3 months ago

djbower commented 3 months ago

Hopefully this is trivial, but I'm having trouble accessing the success/not result of the solver:

Code snippet:

solver = optx.Newton(rtol=tol, atol=tol)
sol = optx.root_find(
    self.objective_function,
    solver,
    initial_solution_guess,
    args=(constraints,),
)
print("sol = ", sol)
result = sol.result
print("result = ", optx.RESULTS[result])

Output of the print statements. Note that nothing is printed after "result = ":

sol =  Solution(
  value=f64[3],
  result=EnumerationItem(
    _value=i32[],
    _enumeration=<class 'optimistix._solution.RESULTS'>
  ),
  aux=None,
  stats={'max_steps': 256, 'num_steps': i64[]},
  state=_NewtonChordState(
    f=f64[3],
    linear_state=None,
    diff=f64[3],
    diffsize=f64[],
    diffsize_prev=f64[],
    result=EnumerationItem(
      _value=i32[],
      _enumeration=<class 'optimistix._solution.RESULTS'>
    ),
    step=i64[]
  )
)
result =  

I was expecting to see "successful" printed as per https://docs.kidger.site/optimistix/api/solution/ (I know the returned result is correct because I have a test suite I am comparing against and the numerical value is correct). I can access a '0' integer if I instead access the result._value attribute, but so far no luck getting the string message. I suspect I am using enums incorrectly.

patrick-kidger commented 3 months ago

It's actually just that success is indicated with no message whatsoever! The messages are only used to indicate error states.

As an example of a solve failing, try running this (which is a problem that does not have a root):

import optimistix as optx

sol = optx.root_find(
    lambda y, args: (1 + y**2), optx.Newton(rtol=1e-3, atol=1e-3), 1.0, throw=False
)
print(sol)
print("\n-----\n")
print(sol.result)