fbartolic / caustics

Differentiable microlensing powered by JAX
https://fbartolic.github.io/caustics/
MIT License

Analytic Solver #17

Open kmzzhang opened 2 years ago

kmzzhang commented 2 years ago

Placeholder pull request for implementing the semi-analytic solver proposed in Zhang (2022) into caustics. Continuation from #16.

kmzzhang commented 2 years ago

@fbartolic I timed it on my machine and they're about the same for 200 points, but ~6x faster for 250,000 points. Perhaps there's overhead somewhere? I'm now working on adding a stopping criterion for Newton's method (and Laguerre's later) and changing the coefficient parametrization.

fbartolic commented 2 years ago

Thanks for working on this! You should probably replace `lax.scan` in `find_isolated_root` with a `lax.while_loop`. I've avoided `lax.while_loop` in caustics so far because it isn't reverse-mode differentiable, but in this case that doesn't matter: we can use implicit differentiation to define a `custom_jvp` for the solver. I don't think there's a need for the Laguerre update because the root solver is always executed sequentially, initialized with the roots from the previous point on the source limb. As for improving the performance of JAX code, in my experience it's often quite mysterious which strategy will lead to a speedup. Keep in mind that everything is JIT-compiled to XLA, which performs various optimizations on the code you write.
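To illustrate the pattern, here's a toy sketch (a scalar equation, not the actual caustics solver; all names are placeholders): a Newton iteration inside `lax.while_loop` with a tolerance-based stopping criterion, wrapped in a `custom_jvp` that differentiates the converged root via the implicit function theorem, so gradients work even though the loop itself isn't reverse-mode differentiable.

```python
import jax
import jax.numpy as jnp

# Toy residual: f(x, p) = x**2 - p, whose root is x*(p) = sqrt(p).
def _f(x, p):
    return x**2 - p

def newton_solve(p, x0=1.0, tol=1e-6, max_iter=50):
    # Newton's method with a stopping criterion: iterate until |f| < tol
    # or max_iter is reached.
    def cond(state):
        x, i = state
        return (jnp.abs(_f(x, p)) > tol) & (i < max_iter)

    def body(state):
        x, i = state
        dfdx = jax.grad(_f, argnums=0)(x, p)
        return x - _f(x, p) / dfdx, i + 1

    init = (jnp.asarray(x0, dtype=jnp.float32), jnp.asarray(0))
    x, _ = jax.lax.while_loop(cond, body, init)
    return x

# Implicit differentiation: f(x*, p) = 0 implies
# dx*/dp = -(df/dp) / (df/dx), evaluated at the converged root.
@jax.custom_jvp
def solve(p):
    return newton_solve(p)

@solve.defjvp
def solve_jvp(primals, tangents):
    (p,), (dp,) = primals, tangents
    x = newton_solve(p)
    dfdx = jax.grad(_f, argnums=0)(x, p)
    dfdp = jax.grad(_f, argnums=1)(x, p)
    return x, -dfdp / dfdx * dp
```

With this rule in place, `jax.grad(solve)(4.0)` gives d√p/dp = 0.25 without ever unrolling or differentiating through the loop.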

The new roots function should first be integrated into `images_point_source` in `point_source_magnification.py` and take an optional set of initialization points (for the isolated root in this case). At the moment I'm calling this function within a `lax.scan` loop in the function `_eval_images_sequentially` in `extended_source_magnification.py`, but I will remove this tonight and instead define another function `_images_point_source_sequential` in `point_source_magnification.py`. The performance of this sequential version is what matters most.

In `_images_point_source_sequential`, for the binary-lens case I think it's best to first evaluate all isolated roots using one `lax.scan` loop and then evaluate the other roots using `roots_of_fourth_order_poly`, either with `vmap` or by making the function work with multi-dimensional arrays (with a batch dimension) by default. You definitely don't want to evaluate these other roots within a `lax.scan`, because they don't need to be evaluated sequentially and JAX will automatically parallelize code over multiple CPU cores whenever possible.
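The structure described above might look something like this sketch (all function names here are illustrative stand-ins, not the real caustics API): a `lax.scan` that carries the isolated root forward as the initialization for the next limb point, followed by a `vmap` over the per-point roots that have no sequential dependence.

```python
import jax
import jax.numpy as jnp

# Stand-in for the sequential refinement of the isolated root: a single
# damped step toward w, just to show the scan structure.
def refine_isolated_root(z_prev, w):
    return 0.5 * (z_prev + w)

# Stand-in for roots_of_fourth_order_poly: independent per source point,
# so it can be batched with vmap instead of run inside the scan.
def other_roots(w):
    return jnp.stack([w, -w, 2.0 * w, -2.0 * w])

def images_sequential(w_points, z_init):
    def step(z_prev, w):
        z = refine_isolated_root(z_prev, w)
        return z, z  # carry the root forward to initialize the next point

    init = jnp.asarray(z_init, dtype=w_points.dtype)
    _, isolated = jax.lax.scan(step, init, w_points)   # shape (n_points,)
    batched = jax.vmap(other_roots)(w_points)          # shape (n_points, 4)
    return isolated, batched
```

The scan expresses only the genuinely sequential dependence; everything independent goes through `vmap`, where XLA is free to parallelize it.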

kmzzhang commented 2 years ago

I timed your last shared version with `vmap(lambda w: roots_semi_analytic(w, a, e1, niter=15))(w_test)`. Solving 250 polynomials took ~1 ms, but solving 250,000 took only 138 ms. For comparison, `poly_roots` scaled linearly: 0.93 ms vs. 930 ms for 250 vs. 250,000 polynomials. So for large enough batches it's about 7x faster, but there's no speedup for a small number of tasks.

I tried putting the `vmap` inside a jitted function:

```python
@partial(jit, static_argnames=("niter",))
def find_roots_vmap(w_array, a, e1, niter):
    return vmap(lambda w: roots_semi_analytic(w, a, e1, niter))(w_array)
```

This took 0.56 ms for solving 250 polynomials, ~2x faster than the naive `vmap`. Any ideas for a better implementation that gets closer to the theoretical 0.13 ms (1/1000th of the time for solving 250,000 polynomials)?
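One thing worth ruling out when chasing those numbers (generic JAX benchmarking advice, not specific to this solver): JAX dispatches work asynchronously and the first call includes compilation, so a fair measurement warms up first and then blocks on the result. A minimal harness might look like:

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(f, *args):
    # First call triggers compilation; block so it finishes before timing.
    out = f(*args)
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    # Time a second call, again blocking on the (asynchronous) result.
    t0 = time.perf_counter()
    out = f(*args)
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    return time.perf_counter() - t0, out
```

For small batches the measured time is often dominated by fixed per-call dispatch overhead rather than the computation itself, which would explain why 250 polynomials don't run 1000x faster than 250,000.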

Also, did you compare typical root-solving time with and without initialization from the roots of the previous point on the source limb?

fbartolic commented 2 years ago

Sorry for the late response; unfortunately I won't have time to look into this for another week or so.

I'm sure we'll manage to work something out though.