patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

correct name of the exception class that Equinox uses for runtime errors #35

Closed TomClarkMassSpec closed 9 months ago

TomClarkMassSpec commented 9 months ago

Hi,

I'm using a couple of Equinox pytrees in my program, and in one case a pytree is used together with Newton root finding from Optimistix. My larger code is a gradient descent variation, and occasionally a data point is expected to have no solution for the root-finding algorithm. In setting up a try/except block, what is the correct name of the exception class that Equinox uses for runtime errors, or should it be something from Optimistix?

    instance_of_acceleration = AccelerationPytree(l_pr, regime, kinetic_conservative, rot_dissapative, ld_dissapative, epd_dissapative_1, qe_conservative_1, epd_dissapative_2, epd_dissapative_3, epd_dissapative_4, qe_conservative_2, qe_conservative_3)

    solver_root = optx.Newton(rtol=1e-8, atol=1e-8)
    y0 = jnp.array(0.1)
    try:
        sol = optx.root_find(fn=time_root_from_distance, solver=solver_root (well_posed=False), y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=20000, throw=False)
        Thv = sol.value
    except eqx.exception_module.EqxRuntimeError. (WHAT GOES HERE?):
        # Set Thv to a default value or handle it accordingly
        Thv = 999.  # Replace with an appropriate default value or action
    print(Thv)
    return Thv

Please and thanks, Tom

patrick-kidger commented 9 months ago

The best way to do this is to call sol = optx.root_find(..., throw=False) and then check whether sol.result == optx.RESULTS.successful.

Based on your example, what you probably want is

Thv = jnp.where(sol.result == optx.RESULTS.successful, sol.value, 999.)

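The reason for taking this approach is that a try/except won't work with JIT compilation: under jax.jit the Python try/except only runs while the function is being traced, whereas a failed solve shows up later, when the compiled code executes. So it's almost never what you really want.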