gboehl / econpizza

Solve nonlinear heterogeneous agent models
MIT License

enhancement: support for jaxopt as a more robust stst solver #13

Closed: drhboss closed this 3 months ago

drhboss commented 4 months ago

I'm trying to solve a rather complex model (~100 equations). The problem is that the Newton solver is not robust enough: if the initial values are a bit off, it doesn't converge. Is there any chance to implement more robust solvers, or to add support for using the optimization routines from jaxopt to solve the model?

gboehl commented 4 months ago

Hi. Hm, I can't think of an application where the Newton-with-pseudoinverse solver would not be the best/most robust choice, but maybe it's just me. Did you check whether the rank of the Jacobian is consistent with the number of provided steady-state values? This is normally the problem. I don't think jaxopt/optax will be of much help, because they mainly target optimization; their root-finding routines are pretty basic. But I'm happy to be convinced otherwise.
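To make the rank check concrete, here is a minimal JAX sketch. The function `residuals` is a hypothetical toy system standing in for a model's steady-state equations, and the sketch assumes the residual function maps a 1-D array of unknowns to a 1-D array of residuals. It compares the numerical rank of the Jacobian at the initial guess with the number of unknowns.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a more reliable rank estimate


def residuals(x):
    # Hypothetical toy system: the third equation is the sum of the first two,
    # so only two of the three equations are independent.
    r1 = x[0] + x[1] - 1.0
    r2 = x[1] - x[2]
    return jnp.array([r1, r2, r1 + r2])


x0 = jnp.array([0.5, 0.2, 0.1])
jac = jax.jacfwd(residuals)(x0)           # 3x3 Jacobian at the initial guess
rank = int(jnp.linalg.matrix_rank(jac))   # numerical rank via SVD
n_unknowns = x0.size

print(f"Jacobian rank: {rank}, unknowns: {n_unknowns}")
if rank < n_unknowns:
    print("Rank-deficient Jacobian: the system is underdetermined at this point.")
```

If the rank falls short of the number of unknowns, a plain Newton step is not well defined, which is exactly the situation the pseudo-inverse is meant to handle.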

Custom solvers are currently not supported, but I agree that this would be a nice feature. What you could do is take the steady-state function and solve it with the solver of your choice. The steady-state function gets compiled when running solve_stst and is then added to the model's context:

https://github.com/gboehl/econpizza/blob/master/econpizza/solvers/steady_state.py#L121

So what you could do is:

```python
# ... parse your model first (e.g. mod = econpizza.load(...)), then:

# needed below
from econpizza.parser import d2jnp

# compile the steady-state function; the solver may fail, but don't raise an error
res_stst = mod.solve_stst(raise_errors=False)

# initial values contain all variables and parameters that are not in fixed_values
init_vals = d2jnp(res_stst['initial_values']['guesses'])
# obtain the compiled steady-state function
stst_fun = mod['context']['func_stst']
# solve it (your_root_finder is a placeholder for the root finder of your choice)
stst = your_root_finder(stst_fun, init_vals)
```
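In the snippet above, `your_root_finder` is only a placeholder. As one possible choice, here is a minimal Levenberg-Marquardt sketch in JAX. Note that this is a swapped-in technique, not the solver econpizza uses, and `levenberg_marquardt` is a hypothetical helper. I have not checked the exact signature of `func_stst`: the sketch assumes a function of a single 1-D array that returns a 1-D residual array, so the compiled function may need a small wrapper.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def levenberg_marquardt(fun, x0, tol=1e-10, max_iter=200, lam=1e-3):
    """Minimal Levenberg-Marquardt root finder (hypothetical helper).

    Assumes `fun` maps a 1-D array to a 1-D residual array.
    Returns the final iterate and a convergence flag.
    """
    jac_fun = jax.jacfwd(fun)
    x = jnp.asarray(x0, dtype=jnp.float64)
    f = fun(x)
    cost = jnp.sum(f**2)
    for _ in range(max_iter):
        if jnp.max(jnp.abs(f)) < tol:
            return x, True
        J = jac_fun(x)
        # Damped Gauss-Newton step: (J^T J + lam * I) dx = -J^T f
        A = J.T @ J + lam * jnp.eye(x.size)
        dx = jnp.linalg.solve(A, -J.T @ f)
        x_new = x + dx
        f_new = fun(x_new)
        cost_new = jnp.sum(f_new**2)
        if cost_new < cost:
            # Step reduced the residual: accept it and trust the Gauss-Newton model more
            x, f, cost = x_new, f_new, cost_new
            lam = lam / 10.0
        else:
            # Step failed: reject it and move toward (scaled) gradient descent
            lam = lam * 10.0
    return x, bool(jnp.max(jnp.abs(f)) < tol)


# Usage on a hypothetical toy system: x0^2 + x1^2 = 4 and x0 = x1
def toy(x):
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])


x_sol, converged = levenberg_marquardt(toy, jnp.array([1.0, 0.5]))
print(x_sol, converged)
```

The damping parameter `lam` makes steps shorter and more gradient-like when the Newton-type step fails to reduce the residual, which is the usual reason this family of methods is more forgiving of poor initial guesses.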

You should then either run solve_stst again with the correct initial values, or update mod['stst'] and mod['pars'] accordingly (see here). Let me know if this helps. Otherwise adding custom solvers should be quite straightforward.

drhboss commented 3 months ago

Thank you for the notes. I was able to plug in my own solver and pass it to your newton_jax steady-state solver. Everything works fine now.


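```python
import jax.numpy as jnp


def solve_singular(A, b):
    # jnp.linalg.svd returns V^T (here Vh); the reduced SVD also handles non-square A
    U, s, Vh = jnp.linalg.svd(A, full_matrices=False)
    # relative cutoff: treat singular values below eps * max(m, n) * s_max as zero
    threshold = jnp.finfo(A.dtype).eps * max(A.shape) * s[0]
    s_inv = jnp.where(s > threshold, 1 / s, 0)
    # pseudo-inverse A^+ = V diag(s_inv) U^T
    pseudo_inverse = Vh.T @ (s_inv[:, None] * U.T)
    x = pseudo_inverse @ b
    return x
```
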
gboehl commented 3 months ago

Great that it worked. Just out of interest, I was wondering what the difference between your solver and mine is. We both do nothing more than compute the pseudo-inverse. Is your threshold different?
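Why the threshold matters can be seen on a tiny nearly singular system. The sketch below uses a hypothetical helper, `truncated_pinv_solve`, that takes the relative cutoff as an argument; the first call uses the eps-based rule from `solve_singular`, the second an arbitrary, looser cutoff of `1e-8`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def truncated_pinv_solve(A, b, rel_cutoff):
    """Solve A x = b with the SVD pseudo-inverse, zeroing singular values
    below rel_cutoff * s_max (hypothetical helper for illustration)."""
    U, s, Vh = jnp.linalg.svd(A, full_matrices=False)
    keep = s > rel_cutoff * s[0]
    # avoid dividing by tiny values that are discarded anyway
    s_inv = jnp.where(keep, 1.0 / jnp.where(keep, s, 1.0), 0.0)
    return Vh.T @ (s_inv * (U.T @ b))


# Nearly singular system: the second singular value is 1e-10 times the first
A = jnp.array([[1.0, 0.0], [0.0, 1e-10]])
b = jnp.array([1.0, 1.0])

eps_cutoff = jnp.finfo(A.dtype).eps * max(A.shape)  # same rule as solve_singular
x_eps = truncated_pinv_solve(A, b, eps_cutoff)  # keeps 1e-10, so x is roughly [1, 1e10]
x_cut = truncated_pinv_solve(A, b, 1e-8)        # drops 1e-10, so x is [1, 0]
print(x_eps, x_cut)
```

With the eps-based rule the near-zero direction is kept and blows up the solution; a larger cutoff discards it and returns the minimum-norm solution of the reduced problem.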

drhboss commented 3 months ago

That's right, they are essentially the same solver, but this way I can modify the threshold manually when the model doesn't solve. However, it doesn't work every time. I do have a model that none of the solvers can solve, but I have a solution obtained manually. If you're interested, I can share it with you.
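The manual threshold knob can be sketched as a Newton iteration whose linear step uses a truncated SVD pseudo-inverse with a user-chosen relative cutoff `rcond`. This is a hypothetical helper (`newton_pinv`), not econpizza's `newton_jax`; the names and defaults are illustrative, and the residual function is assumed to map a 1-D array to a 1-D array.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_pinv(fun, x0, rcond=1e-12, tol=1e-10, max_iter=100):
    """Newton iteration with a truncated pseudo-inverse step (hypothetical helper).
    Raising rcond discards more near-zero directions of the Jacobian."""
    jac_fun = jax.jacfwd(fun)
    x = jnp.asarray(x0, dtype=jnp.float64)
    for it in range(max_iter):
        f = fun(x)
        if jnp.max(jnp.abs(f)) < tol:
            return x, True, it
        U, s, Vh = jnp.linalg.svd(jac_fun(x), full_matrices=False)
        keep = s > rcond * s[0]
        s_inv = jnp.where(keep, 1.0 / jnp.where(keep, s, 1.0), 0.0)
        # minimum-norm Newton step: dx = -J^+ f
        x = x - Vh.T @ (s_inv * (U.T @ f))
    return x, bool(jnp.max(jnp.abs(fun(x))) < tol), max_iter


# Usage on a hypothetical rank-deficient system (third equation = sum of the first two)
def rank_deficient(x):
    r1 = x[0] + x[1] - 1.0
    r2 = x[1] - x[2]
    return jnp.array([r1, r2, r1 + r2])


x_sol, ok, n_iter = newton_pinv(rank_deficient, jnp.array([0.5, 0.2, 0.1]))
print(x_sol, ok, n_iter)
```

Because the Jacobian here is singular everywhere, a plain Newton step is undefined, while the truncated pseudo-inverse picks the minimum-norm step and still reaches a root.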

gboehl commented 3 months ago

Sure!