Closed kp992 closed 3 months ago
Name | Link |
---|---|
Latest commit | e0b4871f56214ba26af972ee273279962b04b374 |
Latest deploy log | https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/65e9532be7116c000874bea6 |
Deploy Preview | https://deploy-preview-146--incomparable-parfait-2417f8.netlify.app |
@jstac This PR adds a callback to `successive_approx` and removes its jitting. Do you think this is good?
Thanks @kp992 . I think this style is consistent with other solution methods, such as https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.optimize.minimize.html
The function that solves the problem receives as arguments the function that it acts on, an initial guess, and other remaining parameters.
Is it necessary to remove jitting?
Does performance change by much?
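For reference, the calling style discussed above (operator first, then an initial guess, then the remaining parameters) can be sketched roughly as follows. This is an illustrative, un-jitted version and not the code in this PR; the parameter names `tolerance` and `max_iter` and the example operator are assumptions.

```python
import jax.numpy as jnp

def successive_approx(T, x_0, tolerance=1e-6, max_iter=10_000):
    """Iterate x <- T(x) from x_0 until successive iterates are close.

    T is passed in as a plain Python callable (the "callback").
    """
    x = x_0
    for _ in range(max_iter):
        x_new = T(x)
        # Stop once the sup-norm distance between iterates is small
        if jnp.max(jnp.abs(x_new - x)) < tolerance:
            return x_new
        x = x_new
    return x

# Hypothetical operator: a contraction whose fixed point is 2.0
T = lambda x: 0.5 * x + 1.0
x_star = successive_approx(T, jnp.array([0.0]))
```

Because the loop runs in Python, each iteration dispatches separately, which is one reason a jitted variant may be faster.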
Thanks for the review @jstac!
I have added `successive_approx_jax` with jitting and the same signature as before (i.e. an operator callback). Performance-wise, it's better than the version that wasn't jitted.
Thanks @kp992 !
Can you tell me why this works for the new version: `static_argnums=(0,)`?
Does this mean that the callable is a "compile-time constant"? I don't get it...
Thanks @jstac. Please see https://github.com/google/jax/issues/1443#issuecomment-542431823 for the reference.
A "static argument" means that (1) the argument can be any Python object, e.g. a callable, and (2) recompilation is triggered for every new value of the argument (based on eq/hash if the object is hashable, or object identity if it's not).
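The static-argument behaviour quoted above can be illustrated with a minimal sketch. It is not the implementation in this PR: the loop body, the parameter names, and the `trace_count` counter (a Python side effect that only runs while JAX traces the function) are assumptions made for illustration.

```python
from functools import partial
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only when JAX traces

@partial(jax.jit, static_argnums=(0,))
def successive_approx_jax(T, x_0, tolerance=1e-6, max_iter=10_000):
    global trace_count
    trace_count += 1  # Python side effect, so it runs at trace time only

    def cond_fun(state):
        x, error, i = state
        return jnp.logical_and(error > tolerance, i < max_iter)

    def body_fun(state):
        x, error, i = state
        x_new = T(x)  # T is static, so it is an ordinary Python callable here
        error = jnp.max(jnp.abs(x_new - x))
        return x_new, error, i + 1

    init = (x_0, jnp.asarray(jnp.inf, dtype=x_0.dtype), jnp.int32(0))
    x, error, i = jax.lax.while_loop(cond_fun, body_fun, init)
    return x

# Hypothetical operators with fixed points 2.0 and 4.0
T1 = lambda x: 0.5 * x + 1.0
T2 = lambda x: 0.5 * x + 2.0

x_0 = jnp.array([0.0])
a = successive_approx_jax(T1, x_0)  # first call with T1: traced and compiled
b = successive_approx_jax(T1, x_0)  # same T1 object: cached version reused
c = successive_approx_jax(T2, x_0)  # different callable: traced again
```

Since functions hash by identity, passing a freshly created lambda on every call would trigger recompilation each time, so it is usually better to define the operator once and reuse it.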
There is also one more approach using https://github.com/google/jax/issues/1443#issuecomment-1527813792. If you want, I can try this too.
This is perfect, thanks @kp992 . Let's stick with this one.
Use callback in `successive_approx` and create JAX-jitted `successive_approx`