exoplanet-dev / jaxoplanet

Astronomical time series analysis with JAX
https://jax.exoplanet.codes
MIT License

feat: diagonal sparse Pijk + discussion on merging limb-dark and starry #203

Open lgrcia opened 1 month ago

lgrcia commented 1 month ago

This PR is an attempt to understand what can be done to merge the limb-darkening light curve implementation with starry.

Motivation

The idea is that, if only limb darkening is present on a map defined with starry, JAX should figure out which parts of the different matrices are needed to compute the starry light curve, and these should be similar in number to the ones involved in core.limb_dark.light_curve. Hence the performance of the starry light curve should be similar to the limb-dark one, i.e. experimental.starry.light_curves.surface_light_curve with

surface = Surface(y=None, u=u)

Description of the current modifications

In the Ylm and Pijk interfaces, we defined a diagonal attribute which indicates whether the non-radial coefficients of the spherical harmonics basis (i.e. those not contributing to limb darkening) are all zero. The first commit in this PR fixes that for the Pijk basis, and allows passing a diagonal kwarg to the Pijk.tosparse method, so that the computation

# see end of jaxoplanet/experimental/starry/light_curves.py

p_y.tosparse(diagonal=only_u) @ design_matrix_p

takes into account the sparsity of p_y.
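For reference, here is a minimal sketch of what a diagonal check and the corresponding reduced product could look like in the Ylm basis. It assumes the flattened starry ordering n = l**2 + l + m and that, with the line of sight along z (where Y_{1,0} is proportional to z), the limb-darkening-only terms are exactly the m == 0 ones. The helpers (`radial_indices`, `is_diagonal`, `reduced_product`) and the dense `design` matrix are hypothetical and are not the Ylm/Pijk interfaces of this PR.

```python
import numpy as np
import jax.numpy as jnp


def ylm_lm(lmax):
    """(l, m) pairs in the flattened ordering n = l**2 + l + m."""
    return [(l, m) for l in range(lmax + 1) for m in range(-l, l + 1)]


def radial_indices(lmax):
    """Flattened indices n = l**2 + l of the m == 0 terms."""
    return np.array([l * l + l for l in range(lmax + 1)])


def is_diagonal(y, lmax, atol=0.0):
    """True if every m != 0 coefficient of y is zero (up to atol)."""
    off = np.setdiff1d(np.arange((lmax + 1) ** 2), radial_indices(lmax))
    return bool(jnp.all(jnp.abs(y[off]) <= atol))


def reduced_product(y, design, lmax):
    """Equal to y @ design when is_diagonal(y, lmax), using only the m == 0 rows."""
    idx = radial_indices(lmax)
    return y[idx] @ design[idx]


# usage: a limb-darkening-only map with lmax = 2 (coefficient values are arbitrary)
lmax = 2
y = jnp.zeros((lmax + 1) ** 2).at[radial_indices(lmax)].set(jnp.array([1.0, 0.3, 0.1]))
design = jnp.arange(9 * 4, dtype=jnp.float32).reshape(9, 4)
print(is_diagonal(y, lmax), jnp.allclose(reduced_product(y, design, lmax), y @ design))
```

The point of the reduced product is that only lmax + 1 rows of the design matrix are touched instead of (lmax + 1)**2.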

Results

For now this is not working: the starry light curve is ~4 times slower than the limb-dark one.

starry

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
439 μs ± 2.82 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

limb-dark

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
115 μs ± 2.04 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In both cases the evaluation time is largely dominated by the call to solution_vector. But I wonder how to write the starry light curve function so that only the relevant parts of the solution vector are computed when only limb darkening is present.

Something I noticed is that with a scalar b (no vmap) we get 9.42 μs (starry) vs. 7.49 μs (limb-dark).
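A side note on the timing lines: because JAX dispatches asynchronously, `%timeit function(jax.block_until_ready(function(b)))` calls `function` twice per loop (feeding its own output back in) and only waits for the inner call. The later benchmarks in this thread use `%timeit jax.block_until_ready(function(b, r))`, which blocks on the timed call itself. A minimal equivalent outside IPython, using only `timeit` and `jax.block_until_ready`; the `benchmark` helper and the toy function are hypothetical:

```python
import statistics
import timeit

import jax
import jax.numpy as jnp


def benchmark(fn, *args, number=1000, repeat=7):
    """Mean and standard deviation (seconds) of one blocking call to fn(*args)."""
    jax.block_until_ready(fn(*args))  # compile and warm up outside the timed region
    totals = timeit.repeat(
        lambda: jax.block_until_ready(fn(*args)), number=number, repeat=repeat
    )
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)


# usage with a toy vmapped function (hypothetical stand-in for a light curve)
f = jax.jit(jax.vmap(lambda b: jnp.sqrt(jnp.maximum(1.0 - b**2, 0.0))))
b = jnp.linspace(0.0, 1.1, 1000)
mean, std = benchmark(f, b, number=100)
print(f"{mean * 1e6:.2f} us +/- {std * 1e6:.2f} us per call")
```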

lgrcia commented 1 month ago

I added a diagonal kwarg to the starry solution_vector, which skips the computation of the m != 0 terms. It works pretty well, but I'd like to understand why the vmapped version is still that slow and whether there is anything we can do.

Here is the new benchmark:

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
206 μs ± 4.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 μs ± 2.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

But for a single b:

8.26 μs ± 56.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.33 μs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
soichiro-hattori commented 1 month ago

I was playing around yesterday and noticed a roughly 20% speed-up when I switched the order of jit and vmap, i.e. vmapping the jitted function!
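For reference, the two compositions side by side on a toy function (`toy_flux` is a hypothetical stand-in, not the starry light curve). Both give the same values; which one is faster depends on the workload and is best measured with a blocking benchmark.

```python
import jax
import jax.numpy as jnp


def toy_flux(b, r):
    """Hypothetical per-point function standing in for a light-curve evaluation."""
    return 1.0 - r**2 * jnp.exp(-(b**2))


# jit of vmap: the batched function is traced and compiled once, then cached
jit_of_vmap = jax.jit(jax.vmap(toy_flux, in_axes=(0, None)))
# vmap of jit: vmap is applied on every call and batches the inner jitted function
vmap_of_jit = jax.vmap(jax.jit(toy_flux), in_axes=(0, None))

b = jnp.linspace(0.0, 1.1, 1000)
print(jnp.allclose(jit_of_vmap(b, 0.1), vmap_of_jit(b, 0.1)))
```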

lgrcia commented 1 month ago

Thanks @soichiro-hattori!

Also, my approach is very wrong... the solution vector is in Green's basis, so I cannot do this the way I did it. I'll work on this.

soichiro-hattori commented 1 month ago

Never mind, I lied @lgrcia!

lgrcia commented 1 month ago

I think I managed to skip the computation of the non-diagonal terms. The Green's basis is actually pretty similar to the polynomial basis, so I computed which indices $(l,m)$ correspond to the off-diagonal terms (in the polynomial basis) and skipped their computation in solution_vector (in both p_integral and q_integral).
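To illustrate how such indices can be found, here is a sketch that maps each flattened index n to the exponents of its polynomial-basis term, following our reading of the polynomial basis of Luger et al. (2019): l = floor(sqrt(n)), m = n - l**2 - l, mu = l - m, nu = l + m, and the term is x**(mu/2) y**(nu/2) for even nu and x**((mu-1)/2) y**((nu-1)/2) z for odd nu. Since z**2 = 1 - x**2 - y**2 on the visible disk, a map that depends only on z can only populate terms with even powers of x and y. This is not necessarily how the indices were computed in this PR, and the function names are hypothetical.

```python
import math


def poly_exponents(n):
    """(i, j, k) such that the n-th polynomial basis term is x**i * y**j * z**k."""
    l = math.isqrt(n)
    m = n - l * l - l
    mu, nu = l - m, l + m
    if nu % 2 == 0:
        return mu // 2, nu // 2, 0
    return (mu - 1) // 2, (nu - 1) // 2, 1


def radial_poly_indices(lmax):
    """Indices n whose term has even powers of x and y, i.e. can appear in a map of z only."""
    return [
        n
        for n in range((lmax + 1) ** 2)
        if poly_exponents(n)[0] % 2 == 0 and poly_exponents(n)[1] % 2 == 0
    ]


print([poly_exponents(n) for n in range(9)])  # 1, x, z, y, x^2, xz, xy, yz, y^2
print(radial_poly_indices(3))  # [0, 2, 4, 8, 10, 14]
```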

Here are some benchmarks:

import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = 0.1
order = 20

function = jax.jit(solution_vector(2, order, diagonal=False))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(solution_vector(2, order, diagonal=True))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(solution_vector(2, order))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
6.44 μs ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Only diagonal terms
6.05 μs ± 24.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

core.limb_dark version
5.46 μs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

But if we consider a vmapped version:

import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = jax.numpy.linspace(0, 1 + r, 1000)
order = 20

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=False)), (0, None)))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=True)), (0, None)))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order)), (0, None)))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
295 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Only diagonal terms
190 μs ± 4.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

core.limb_dark version
74.5 μs ± 439 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

It would be nice to know where the difference is coming from.

dfm commented 1 month ago

I'm starting to take a look at this; we'll see how far I get. To begin with, I'm looking at the jaxpr for each calculation. I started with just the zeroth-order computation, which should be identical in both cases, and added jax.make_jaxpr(function)(b, r) after each benchmark. Here are the two jaxprs:

starry ``` let _where = { lambda ; a:bool[] b:f64[] c:i64[]. let d:f64[] = convert_element_type[new_dtype=float64 weak_type=True] c e:f64[] = select_n a d b in (e,) } in let _where1 = { lambda ; f:bool[] g:i64[] h:f64[]. let i:f64[] = convert_element_type[new_dtype=float64 weak_type=True] g j:f64[] = select_n f h i in (j,) } in { lambda ; k:f64[] l:f64[]. let m:f64[1] = pjit[ name=impl jaxpr={ lambda ; n:f64[] o:f64[]. let p:f64[1] = pjit[ name=impl jaxpr={ lambda q:f64[20] r:f64[1,20] s:i64[1]; t:f64[] u:f64[]. let v:f64[] = abs t w:f64[] = abs u x:f64[] = integer_pow[y=2] v y:f64[] = sub w 1.0 z:f64[] = add w 1.0 ba:f64[] = mul y z bb:f64[] = sub 1.0 w bc:f64[] = abs bb bd:bool[] = gt v bc be:f64[] = add 1.0 w bf:bool[] = lt v be bg:bool[] = convert_element_type[new_dtype=bool weak_type=False] bd bh:bool[] = convert_element_type[new_dtype=bool weak_type=False] bf bi:bool[] = and bg bh bj:f64[] = pjit[name=_where jaxpr=_where] bi v 1 bk:f64[] = min w bj bl:f64[] = max w bj bm:f64[] = min bl 1.0 bn:f64[] = max bl 1.0 bo:f64[] = min bk bm bp:f64[] = max bk bm bq:f64[] = add bp bn br:f64[] = add bo bq bs:f64[] = sub bo bp bt:f64[] = sub bn bs bu:f64[] = mul br bt bv:f64[] = sub bo bp bw:f64[] = add bn bv bx:f64[] = mul bu bw by:f64[] = sub bp bn bz:f64[] = add bo by ca:f64[] = mul bx bz cb:f64[] = max 0.0 ca cc:f64[] = custom_jvp_call[ call_jaxpr={ lambda ; cd:f64[]. 
let ce:f64[] = sqrt cd in (ce,) } jvp_jaxpr_thunk=.memoized at 0x176fe2fc0> num_consts=0 symbolic_zeros=False ] cb cf:f64[] = pjit[name=_where jaxpr=_where] bi cc 0 cg:f64[] = add x ba ch:f64[] = atan2 cf cg ci:f64[] = sub x ba cj:f64[] = atan2 cf ci ck:f64[] = integer_pow[y=2] v cl:f64[] = integer_pow[y=2] w cm:f64[] = mul 4.0 v cn:f64[] = mul cm w co:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] cn cp:bool[] = lt co 2.220446049250313e-15 cq:f64[] = pjit[name=_where jaxpr=_where1] cp 1 cn cr:f64[] = sub 1.0 cl cs:f64[] = sub cr ck ct:f64[] = mul 2.0 v cu:f64[] = mul ct w cv:f64[] = add cs cu cw:f64[] = div cv cq cx:f64[] = max 0.0 cw cy:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] w cz:bool[] = lt cy 2.220446049250313e-15 da:f64[] = sub v w db:f64[] = pjit[name=_where jaxpr=_where1] cz 1 w dc:f64[] = mul 2.0 db dd:f64[] = div da dc de:f64[] = mul 0.5 ch df:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] de dg:f64[20] = mul df q dh:f64[] = mul 0.5 ch di:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] dh dj:f64[20] = add dg di _:f64[20] = cos dj dk:f64[20] = sin dg dl:f64[20] = integer_pow[y=2] dk dm:f64[] = sub 1.0 cl dn:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] cx do:f64[20] = sub dn dl dp:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] cq dq:f64[20] = mul dp do dr:f64[20] = pjit[ name=_where jaxpr={ lambda ; ds:bool[] dt:f64[] du:f64[20]. let dv:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] dt dw:f64[20] = broadcast_in_dim[ broadcast_dimensions=() shape=(20,) ] dv dx:f64[20] = select_n ds du dw in (dx,) } ] cp dm dq dy:f64[20] = max 0.0 dr dz:f64[20] = pow dy 1.5 ea:f64[20] = integer_pow[y=2] dl eb:f64[20] = sub dl ea ec:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] dd ed:f64[20] = add ec dl ee:f64[20] = pjit[ name=_where jaxpr={ lambda ; ef:bool[] eg:i64[] eh:f64[20]. 
let ei:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] eg ej:f64[20] = broadcast_in_dim[ broadcast_dimensions=() shape=(20,) ] ei ek:f64[20] = select_n ef eh ej in (ek,) } ] cz 0 ed el:f64[20] = mul 2.0 dl _:f64[20] = sub 1.0 el em:f64[] = mul 2.0 w en:f64[] = integer_pow[y=-1] em eo:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] en _:f64[20] = mul eo dz ep:f64[] = mul 2.0 w eq:f64[] = integer_pow[y=2] ep er:f64[] = mul 2.0 eq es:f64[20] = pow eb 1.0 et:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] er eu:f64[20] = mul et es ev:f64[20] = pow ee 0.0 ew:f64[20] = mul eu ev ex:f64[1,20] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 20) ] ew ey:f64[1,20] = mul ex r ez:f64[1] = reduce_sum[axes=(1,)] ey fa:f64[] = convert_element_type[ new_dtype=float64 weak_type=False ] de fb:f64[1] = mul fa ez fc:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0.0 fd:i64[1] = device_put[devices=[None] srcs=[None]] s fe:bool[1] = lt fd 0 ff:i64[1] = add fd 1 fg:i64[1] = select_n fe fd ff fh:i32[1] = convert_element_type[ new_dtype=int32 weak_type=False ] fg fi:i32[1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(1, 1) ] fh fj:f64[1] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] fc fi fb fk:f64[] = sub 1.5707963267948966 cj fl:f64[] = cos fk fm:f64[] = sin fk fn:f64[] = mul 2.0 fk fo:f64[] = add fn 3.141592653589793 _:f64[] = mul -2.0 fl fp:f64[] = integer_pow[y=1] fl fq:f64[] = mul 2.0 fp fr:f64[] = integer_pow[y=1] fm fs:f64[] = mul fq fr ft:f64[] = mul 1.0 fo fu:f64[] = add fs ft fv:f64[] = div fu 2.0 fw:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] fv fx:f64[1] = convert_element_type[ new_dtype=float64 weak_type=False ] fw fy:f64[1] = sub fx fj in (fy,) } ] n o in (p,) } ] 
k l in (m,) } ```
limb_dark ``` let _where = { lambda ; a:bool[] b:f64[] c:f64[]. let d:f64[] = select_n a c b in (d,) } in let _where1 = { lambda ; e:bool[] f:f64[] g:f64[]. let h:f64[] = select_n e g f in (h,) } in { lambda ; i:f64[] j:f64[]. let k:f64[1] = pjit[ name=impl jaxpr={ lambda ; l:f64[] m:f64[]. let n:f64[] = abs l o:f64[] = abs m p:f64[] = integer_pow[y=2] n q:f64[] = sub o 1.0 r:f64[] = add o 1.0 s:f64[] = mul q r t:f64[] = sub 1.0 o u:f64[] = abs t v:bool[] = gt n u w:f64[] = add 1.0 o x:bool[] = lt n w y:bool[] = convert_element_type[new_dtype=bool weak_type=False] v z:bool[] = convert_element_type[new_dtype=bool weak_type=False] x ba:bool[] = and y z bb:f64[] = pjit[name=_where jaxpr=_where] ba n 1.0 bc:f64[] = min o bb bd:f64[] = max o bb be:f64[] = min bd 1.0 bf:f64[] = max bd 1.0 bg:f64[] = min bc be bh:f64[] = max bc be bi:f64[] = add bh bf bj:f64[] = add bg bi bk:f64[] = sub bg bh bl:f64[] = sub bf bk bm:f64[] = mul bj bl bn:f64[] = sub bg bh bo:f64[] = add bf bn bp:f64[] = mul bm bo bq:f64[] = sub bh bf br:f64[] = add bg bq bs:f64[] = mul bp br bt:f64[] = max 0.0 bs bu:f64[] = custom_jvp_call[ call_jaxpr={ lambda ; bv:f64[]. 
let bw:f64[] = sqrt bv in (bw,) } jvp_jaxpr_thunk=.memoized at 0x176fe28e0> num_consts=0 symbolic_zeros=False ] bt bx:f64[] = pjit[name=_where jaxpr=_where] ba bu 0.0 by:f64[] = add p s bz:f64[] = atan2 bx by ca:f64[] = sub p s cb:f64[] = atan2 bx ca cc:f64[] = add 1.0 o cd:bool[] = ge n cc ce:f64[] = add 1.0 n cf:bool[] = le ce o cg:bool[] = convert_element_type[new_dtype=bool weak_type=False] cd ch:bool[] = convert_element_type[new_dtype=bool weak_type=False] cf ci:bool[] = or cg ch cj:f64[] = pjit[name=_where jaxpr=_where] ci 1.0 n ck:f64[] = integer_pow[y=2] cj cl:f64[] = integer_pow[y=2] o cm:f64[] = add cj o cn:f64[] = add 1.0 cm co:f64[] = sub 1.0 cm cp:f64[] = mul cn co cq:f64[] = mul 0.5 cl cr:f64[] = mul 2.0 ck cs:f64[] = add cl cr ct:f64[] = mul cq cs cu:f64[] = sub 1.0 cl cv:f64[] = mul 3.141592653589793 cu cw:f64[] = mul 2.0 cv cx:f64[] = sub ct 0.5 cy:f64[] = mul 12.566370614359172 cx cz:f64[] = add cw cy da:f64[] = mul cl bz db:f64[] = add cb da dc:f64[] = mul bx 0.5 dd:f64[] = sub db dc de:f64[] = sub 3.141592653589793 dd df:f64[] = mul 2.0 de dg:f64[] = sub 3.141592653589793 cb dh:f64[] = neg dg di:f64[] = mul 2.0 ct dj:f64[] = mul di bz dk:f64[] = add dh dj dl:f64[] = mul 0.25 bx dm:f64[] = mul 5.0 cl dn:f64[] = add 1.0 dm do:f64[] = add dn ck dp:f64[] = mul dl do dq:f64[] = sub dk dp dr:f64[] = mul 2.0 dq ds:f64[] = add df dr dt:f64[] = mul 4.0 cj du:f64[] = mul dt o dv:f64[] = add cp du dw:bool[] = gt dv du dx:f64[] = pjit[name=_where jaxpr=_where1] dw cv de dy:f64[] = pjit[name=_where jaxpr=_where1] dw cz ds dz:f64[] = pjit[name=_where jaxpr=_where1] cd 3.141592653589793 dx ea:f64[] = pjit[name=_where jaxpr=_where1] cf 0.0 dz _:f64[] = pjit[name=_where jaxpr=_where] ci 0.0 dy eb:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ea in (eb,) } ] i j in (k,) } ```

They seem to start off the same, but in the middle I'm seeing a lot more device_put and scatter operations in the starry version, so I'm going to see if I can track those down!
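To compare jaxprs like these without reading them by eye, one option is to count primitives, including those nested inside pjit and custom_jvp_call. A minimal sketch: the helper names are hypothetical, and nested jaxprs are found by duck typing on `eqns` / `jaxpr` attributes rather than through `jax.core` types, whose import location has moved across JAX versions. Something like `count_primitives(function, b, r)` would then give the scatter and device_put counts for each implementation.

```python
from collections import Counter

import jax
import jax.numpy as jnp


def _sub_jaxprs(param):
    """Yield jaxprs nested in an equation parameter (duck-typed)."""
    if hasattr(param, "jaxpr") and hasattr(param.jaxpr, "eqns"):  # ClosedJaxpr
        yield param.jaxpr
    elif hasattr(param, "eqns"):  # open Jaxpr
        yield param
    elif isinstance(param, (tuple, list)):  # e.g. the branches of lax.cond
        for p in param:
            yield from _sub_jaxprs(p)


def _iter_eqns(jaxpr):
    """Depth-first walk over all equations, including nested ones."""
    for eqn in jaxpr.eqns:
        yield eqn
        for param in eqn.params.values():
            for sub in _sub_jaxprs(param):
                yield from _iter_eqns(sub)


def count_primitives(fun, *args):
    """Counter of primitive names appearing in the jaxpr of fun(*args)."""
    closed = jax.make_jaxpr(fun)(*args)
    return Counter(eqn.primitive.name for eqn in _iter_eqns(closed.jaxpr))


# usage: updates at array indices show up as scatter
counts = count_primitives(lambda x: jnp.zeros(3).at[jnp.array([0, 2])].set(jnp.sin(x)), 1.0)
print(counts["scatter"], counts["sin"])
```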

dfm commented 1 month ago

Oh, I think I know what it is! We should probably just merge the two implementations: use the closed-form solutions for s0 and s2, and keep the numerical solutions for the other terms. I still think we might be able to optimize the starry implementation of the numerical solutions too, though. Where are those scatters coming from?!
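For the uniform term, the closed form reduces to the classic area of overlap of two circles. A sketch, assuming the convention in which the unocculted uniform term equals pi (the area of the unit disk), so that s0 is pi minus the occulted area. `uniform_s0` is a hypothetical name, s2 is not covered, and the `jnp.where` guards only keep the forward pass finite; they are not meant to give clean gradients at the branch points.

```python
import jax.numpy as jnp


def uniform_s0(b, r):
    """Unocculted area of the unit disk for an occultor of radius r at separation b.

    Assumes the convention where the unocculted uniform term equals pi.
    Forward values only: the jnp.where branches are not made gradient-safe.
    """
    b = jnp.abs(b)
    r = jnp.abs(r)
    # avoid 0/0 in the partial-overlap branch when jnp.where discards it anyway
    b_safe = jnp.where(b > 0, b, 1.0)
    r_safe = jnp.where(r > 0, r, 1.0)
    # half-opening angles of the two circular sectors bounding the overlap lens
    kappa0 = jnp.arccos(jnp.clip((b**2 + r**2 - 1.0) / (2.0 * b_safe * r_safe), -1.0, 1.0))
    kappa1 = jnp.arccos(jnp.clip((b**2 + 1.0 - r**2) / (2.0 * b_safe), -1.0, 1.0))
    # 0.5 * root is the area of the kite formed by both centers and both intersection points
    root = jnp.sqrt(jnp.maximum((-b + r + 1.0) * (b + r - 1.0) * (b - r + 1.0) * (b + r + 1.0), 0.0))
    partial = r**2 * kappa0 + kappa1 - 0.5 * root
    overlap = jnp.where(
        b >= 1.0 + r,
        0.0,  # disks do not overlap
        jnp.where(b <= jnp.abs(1.0 - r), jnp.pi * jnp.minimum(1.0, r) ** 2, partial),
    )
    return jnp.pi - overlap
```

Since every operation is elementwise, the function also works directly on an array of b values.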

dfm commented 1 month ago

See https://github.com/exoplanet-dev/jaxoplanet/pull/204

lgrcia commented 1 month ago

The matrices involved in the light curve computation from Agol 2019 (polynomial limb-darkening) and Luger 2019 (more general) are inherently of different sizes, even if Luger matrices are reduced to the minimal case of a limb-darkened surface (see https://github.com/exoplanet-dev/jaxoplanet/pull/204#issuecomment-2323046728).

So I suspect we will never really get the same performance from the two. But maybe I'm wrong.

One thing I did to bridge the gap was to compute the change-of-basis matrix from Agol's Green's basis to Luger's polynomial basis, so that Agol's smaller solution vector can be used directly in the limb-darkened case of a starry surface. Combined with excluding the non-diagonal values from Luger's matrices, this makes performance much better. But again, I'm not sure we can push it further.
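The linear algebra behind reusing a solution vector from another basis can be checked on a 1D toy problem (not the actual Agol or Luger bases): the intensity is written in powers of mu, the "solution vector" is only known for the shifted basis (1 - mu)**k, and each term is integrated against 2 mu on [0, 1]. If M maps shifted-basis coefficients to power-basis coefficients, c_power = M @ c_shifted, then flux = s_shifted @ inv(M) @ c_power = solve(M.T, s_shifted) @ c_power, so the converted vector can be precomputed once. All names are hypothetical.

```python
import math

import jax.numpy as jnp


def shifted_to_power(deg):
    """M[j, k] = coefficient of mu**j in (1 - mu)**k, so c_power = M @ c_shifted."""
    return jnp.array(
        [[math.comb(k, j) * (-1) ** j for k in range(deg + 1)] for j in range(deg + 1)],
        dtype=jnp.float32,
    )


deg = 2
M = shifted_to_power(deg)
# toy "solution vectors": integral of each basis term times 2 mu over [0, 1]
s_shifted = jnp.array([2.0 / ((k + 1) * (k + 2)) for k in range(deg + 1)])  # (1 - mu)**k
s_power = jnp.array([2.0 / (j + 2) for j in range(deg + 1)])  # mu**j, for checking only

# precomputed once: flux = s_shifted @ solve(M, c) = solve(M.T, s_shifted) @ c
s_eff = jnp.linalg.solve(M.T, s_shifted)

c_power = jnp.array([1.0, 0.5, 0.25])  # intensity 1 + 0.5 mu + 0.25 mu**2
print(s_eff @ c_power, s_power @ c_power)
```

Precomputing `s_eff` means the per-evaluation cost is a single dot product with the coefficients, whatever basis the solution vector was derived in.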

With these changes, processing times for single values of b and r are equal! But I still don't understand why the vmapped version of starry behaves differently (although its performance is now closer to the limb-dark one).

Benchmark of the new version

vmapped

import jax

jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 1.
u = (1.0, 1.0)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
207 μs ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
132 μs ± 6.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

This is the worst time difference we will get, as processing times become increasingly comparable as order and the degree of the polynomial limb-darkening law increase.

Single b

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = 0.1
order = 20

surface = Surface(u=u)
function = jax.jit(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
7.27 μs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.42 μs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

This time, the evaluation times for a single b are equal, and why this does not carry over to the vmapped version is still a mystery...

I think these changes contain the idea of #204, since we are working with the smallest possible starry matrices (without the zeros). Looking forward to hearing your ideas on this.