QuantEcon / lecture-jax

Lectures on Quantitative Economics Using JAX
https://jax.quantecon.org/

Use callback in `successive_approx` #146

Closed: kp992 closed this 3 months ago

kp992 commented 3 months ago

Use a callback in `successive_approx` and create a JAX-jitted version of `successive_approx`.

netlify[bot] commented 3 months ago

Deploy Preview for incomparable-parfait-2417f8 ready!

Latest commit: e0b4871f56214ba26af972ee273279962b04b374
Latest deploy log: https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/65e9532be7116c000874bea6
Deploy Preview: https://deploy-preview-146--incomparable-parfait-2417f8.netlify.app

github-actions[bot] commented 3 months ago

🚀 Deployed on https://65e95677064fd2a8349dd068--incomparable-parfait-2417f8.netlify.app

kp992 commented 3 months ago

@jstac This PR makes `successive_approx` take the operator as a callback argument and removes its jitting. Does this look good to you?

jstac commented 3 months ago

Thanks @kp992. I think this style is consistent with other solution methods, such as https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.optimize.minimize.html

The function that solves the problem receives as arguments the function that it acts on, an initial guess, and other remaining parameters.
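A minimal sketch of that calling convention (not the PR's actual code, which is not reproduced here): the operator `T` comes first, then the initial guess `x_0`, then the remaining parameters, mirroring the `fun, x0, ...` ordering of `jax.scipy.optimize.minimize`. The sup-norm stopping rule and the default `tolerance` and `max_iter` values are assumptions chosen for illustration.

```python
import jax.numpy as jnp


def successive_approx(T,                 # operator to iterate (the "callback")
                      x_0,               # initial guess
                      tolerance=1e-6,    # stop once successive iterates are this close
                      max_iter=10_000):  # hard cap on the number of iterations
    """Iterate x_{k+1} = T(x_k) until the sup-norm change falls below tolerance."""
    x = x_0
    error = tolerance + 1  # start above tolerance so the loop runs at least once
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        x = x_new
        k += 1
    return x


# Usage: T(x) = 0.5 x + 1 is a contraction whose fixed point is 2.0.
T = lambda x: 0.5 * x + 1.0
x_star = successive_approx(T, jnp.zeros(3))
```

This plain Python loop is not jitted, so each iteration dispatches `T` separately, which is the performance question raised below.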

Is it necessary to remove jitting?

Does performance change much?

kp992 commented 3 months ago

Thanks for the review @jstac!

I have added `successive_approx_jax`, which is jitted and keeps the same signature as before (i.e., it takes the operator as a callback). Performance-wise, it's better than the version without jitting.
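One way a jitted version with the same signature could look, sketched under assumptions: the loop is expressed with `jax.lax.while_loop` (our choice; the PR's implementation may differ), and the operator is marked static with `static_argnums=(0,)` so that an arbitrary Python callable can be passed in.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Argument 0 (the operator T) is static: JAX treats it as a compile-time
# constant, so it can be any hashable Python callable rather than an array.
@partial(jax.jit, static_argnums=(0,))
def successive_approx_jax(T, x_0, tolerance=1e-6, max_iter=10_000):
    """Jitted fixed-point iteration built on jax.lax.while_loop."""

    def cond_fun(state):
        x, error, k = state
        # Keep iterating while not converged and still under the iteration cap
        return (error > tolerance) & (k <= max_iter)

    def body_fun(state):
        x, error, k = state
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        return x_new, error, k + 1

    # Initial error is set above tolerance so the loop body runs at least once;
    # dtypes are fixed up front so the loop carry keeps a consistent type.
    init_state = (x_0, jnp.asarray(tolerance + 1.0, dtype=x_0.dtype), jnp.int32(1))
    x, _, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return x


T = lambda x: 0.5 * x + 1.0
x_star = successive_approx_jax(T, jnp.zeros(3))
```

When comparing timings against the non-jitted loop, exclude the first call (it includes tracing and compilation) and call `.block_until_ready()` on the result, since JAX dispatches work asynchronously.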

jstac commented 3 months ago

Thanks @kp992 !

Can you tell me why `static_argnums=(0,)` works for the new version?

Does this mean that the callable is treated as a "compile-time constant"? I don't get it...

kp992 commented 3 months ago

Thanks @jstac. Please see https://github.com/google/jax/issues/1443#issuecomment-542431823 for reference.

A "static argument" means that (1) the argument can be any Python object, e.g. a callable, and (2) recompilation is triggered for every new value of the argument (based on eq/hash if the object is hashable, or object identity if it's not).
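A small illustration of the recompilation rule quoted above, using a hypothetical `apply_twice` function and a global `trace_count` introduced only for this demo. The counter is a Python side effect, so it increments only when JAX traces (and then compiles) the function, not on cached calls. Plain functions and lambdas hash by object identity, so reusing the same function object hits the cache while a fresh lambda forces a retrace.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only when JAX (re)traces


@partial(jax.jit, static_argnums=(0,))
def apply_twice(f, x):
    global trace_count
    trace_count += 1  # runs at trace time only, not on cache hits
    return f(f(x))


def double(x):
    return 2.0 * x


x = jnp.ones(3)
apply_twice(double, x)             # first call: trace and compile
apply_twice(double, x)             # same function object, same shape: cache hit
apply_twice(lambda y: 2.0 * y, x)  # new lambda object: new hash, so retrace
```

For `successive_approx_jax` this suggests defining the operator once and reusing it: building a new lambda or closure for `T` on every call would trigger a fresh compilation each time.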

kp992 commented 3 months ago

There is also another approach, described in https://github.com/google/jax/issues/1443#issuecomment-1527813792. If you want, I can try that too.

jstac commented 3 months ago

This is perfect, thanks @kp992. Let's stick with this one.