kmzzhang opened 2 years ago
@fbartolic I timed it on my machine: the two are about the same speed for 200 points, but the new solver is ~6x faster for 250,000 points. Perhaps there's overhead somewhere? I'm working on adding a stopping criterion for Newton's method (and Laguerre's later) and changing the coefficient parametrization now.
Thanks for working on this! You should probably replace the `lax.scan` in `find_isolated_root` with a `lax.while_loop`. I've avoided `lax.while_loop` in `caustics` so far because it isn't reverse-mode differentiable, but in this case that doesn't matter: we can use implicit differentiation to define a `custom_jvp` for the solver. I don't think there's a need for the Laguerre update, because the root solver is always executed sequentially, initialized with the roots from the previous point on the source limb. As for improving the performance of JAX code, in my experience it's often quite mysterious which strategy will lead to a speedup. Keep in mind that everything is JIT-compiled to XLA, which performs various optimizations on the code you write.
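Roughly what I have in mind, as a self-contained sketch with a toy polynomial standing in for the actual lens-equation polynomial (none of this is the real `caustics` code; names and tolerances are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # caustics works in float64/complex128

# Toy holomorphic polynomial standing in for the lens-equation polynomial.
def p(z, w):
    return z**2 - w

dp_dz = jax.grad(p, argnums=0, holomorphic=True)
dp_dw = jax.grad(p, argnums=1, holomorphic=True)

@jax.custom_jvp
def find_isolated_root(z0, w):
    # Newton iteration inside a lax.while_loop with a tolerance-based stop.
    z0 = jnp.asarray(z0, dtype=jnp.complex128)

    def cond(state):
        _, step, i = state
        return (jnp.abs(step) > 1e-12) & (i < 50)

    def body(state):
        z, _, i = state
        step = p(z, w) / dp_dz(z, w)
        return z - step, step, i + 1

    z, _, _ = lax.while_loop(cond, body, (z0, jnp.asarray(jnp.inf + 0j), 0))
    return z

@find_isolated_root.defjvp
def find_isolated_root_jvp(primals, tangents):
    # Implicit differentiation: p(z(w), w) = 0 implies
    # dz = -(dp/dw) / (dp/dz) * dw at the converged root, so the
    # while_loop itself never needs to be differentiated.
    z0, w = primals
    _, dw = tangents
    z = find_isolated_root(z0, w)
    return z, -dp_dw(z, w) / dp_dz(z, w) * dw
```

(JAX also ships `jax.lax.custom_root`, which packages exactly this implicit-differentiation pattern, in case hand-writing the JVP gets tedious.)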
The new roots function should first be integrated into `images_point_source` in `point_source_magnification.py` and take an optional set of initialization points (for the isolated root in this case). At the moment I'm calling this function within a `lax.scan` loop in the function `_eval_images_sequentially` in `extended_source_magnification.py`, but I will remove this tonight and instead define another function, `_images_point_source_sequential`, in `point_source_magnification.py`. The performance of this sequential version is what matters most.
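For the optional initialization points, I'm picturing a signature along these lines (the default guess and the body are placeholders just so the sketch stands alone):

```python
import jax.numpy as jnp

def images_point_source(w, a, e1, z_init=None):
    # Optional warm start: the sequential caller passes the isolated root
    # from the previous limb point; otherwise fall back to a generic guess.
    if z_init is None:
        z_init = jnp.zeros_like(w)
    # ... solve for all images, warm-starting the isolated root at z_init ...
    return z_init  # placeholder return so the sketch runs
```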
In `_images_point_source_sequential`, for the binary-lens case I think it's best to first evaluate all isolated roots using one `lax.scan` loop and then evaluate the other roots using `roots_of_fourth_order_poly`, either with `vmap` or by making the function accept multi-dimensional arrays with a batch dimension by default. You definitely don't want to evaluate these other roots within a `lax.scan`, because they don't need to be evaluated sequentially, and JAX will automatically parallelize code over multiple CPU cores whenever possible.
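Schematically, something like this, where the solver bodies are toy stand-ins just so the sketch runs; only the `lax.scan`/`vmap` structure is the point:

```python
import jax.numpy as jnp
from jax import lax, vmap

# Toy stand-ins for the real solvers discussed above, which would live in
# point_source_magnification.py and take whatever arguments they need.
def find_isolated_root(z_prev, w, a, e1):
    return z_prev - (z_prev**2 - w) / (2 * z_prev)  # one toy Newton step

def roots_of_fourth_order_poly(w, a, e1):
    return jnp.full(4, w)  # placeholder

def _images_point_source_sequential(w_limb, a, e1, z_init):
    # Isolated root: tracked sequentially along the source limb with
    # lax.scan, each step warm-started with the previous point's root.
    def step(z_prev, w):
        z = find_isolated_root(z_prev, w, a, e1)
        return z, z

    _, z_isolated = lax.scan(step, z_init, w_limb)

    # Remaining roots: independent across limb points, so evaluate them in
    # one batched call; XLA can parallelize this across CPU cores.
    z_other = vmap(lambda w: roots_of_fourth_order_poly(w, a, e1))(w_limb)
    return z_isolated, z_other
```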
I timed your last shared version with `vmap(lambda w: roots_semi_analytic(w, a, e1, niter=15))(w_test)`. Solving 250 polynomials took ~1 ms, but solving 250,000 took only 138 ms. As a comparison, `poly_roots` scales linearly with the number of polynomials: 0.93 ms for 250 vs. 930 ms for 250,000 (1000x longer for 1000x the work). So for large enough tasks the semi-analytic solver is about 7x faster, but there's no acceleration for a small number of tasks.
I tried putting the `vmap` into a jitted function:

```python
from functools import partial
from jax import jit, vmap

@partial(jit, static_argnames=("niter",))
def find_roots_vmap(w_array, a, e1, niter):
    return vmap(lambda w: roots_semi_analytic(w, a, e1, niter))(w_array)
```
This took 0.56 ms to solve 250 polynomials, which is ~2x faster than the naive `vmap`. Any ideas for a better implementation that gets closer to the theoretical ~0.14 ms (1/1000th of the time to solve 250,000 polynomials)?
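For what it's worth, I'm measuring with a warmed-up, blocking loop roughly like the sketch below (`w_array` etc. as above), so compilation time and JAX's asynchronous dispatch shouldn't be skewing the numbers:

```python
import time

# Warm up once so compilation isn't counted in the measurement.
find_roots_vmap(w_array, a, e1, niter=15).block_until_ready()

# Block on each result so wall-clock time reflects actual execution,
# not just the asynchronous dispatch.
t0 = time.perf_counter()
for _ in range(100):
    find_roots_vmap(w_array, a, e1, niter=15).block_until_ready()
print(f"{(time.perf_counter() - t0) / 100 * 1e3:.3f} ms per call")
```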
Also, did you compare typical root-solving times with and without initializing from the roots of the previous point on the source limb?
Sorry for the late response; I won't have time to look into this for another week or so, unfortunately.
I'm sure we'll manage to work something out though.
Placeholder pull request for implementing the semi-analytic solver proposed in Zhang (2022) into caustics. Continuation from #16.