lgrcia opened 1 month ago
I added a `diagonal` kwarg to the `solution_vector` of starry, which leads to skipping the computation for m != 0. It works pretty well, but I'd like to understand why the vmap is still that slow and whether there is anything we can do.
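For context, a minimal sketch of which flattened indices the skipped terms leave behind, assuming the usual starry ordering $(l, m) \mapsto n = l^2 + l + m$ (`radial_indices` is a hypothetical helper, not a jaxoplanet function):

```python
def radial_indices(l_max):
    # In the ordering n = l**2 + l + m, the m == 0 ("radial") terms, the only
    # ones a purely limb-darkened map needs, sit at n = l * (l + 1).
    return [l * (l + 1) for l in range(l_max + 1)]

print(radial_indices(2))  # the three radial entries of a degree-2 basis
```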
Here is the new benchmark:
```python
import jax
import numpy as np

from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve

r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
```
```
206 μs ± 4.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 μs ± 2.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
But for a single `b`:
```
8.26 μs ± 56.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.33 μs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
I was playing around yesterday and noticed that I got a roughly 20% speed increase when I switched the order of `jit` and `vmap`, i.e. vmapping the jitted function!
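As a toy illustration of the two compositions (using a stand-in function, not the starry one), both orderings give identical values; only the compile/dispatch behavior differs:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2  # stand-in for the per-sample light curve

jit_of_vmap = jax.jit(jax.vmap(f))  # jit the already-batched function
vmap_of_jit = jax.vmap(jax.jit(f))  # vmap the jitted scalar function

x = jnp.linspace(0.0, 1.0, 8)
assert jnp.allclose(jit_of_vmap(x), vmap_of_jit(x))  # same results either way
```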
Thanks @soichiro-hattori!
Also, my approach is very wrong... the solution vector is in Green's basis, so I cannot do what I did the way I did it. I'll work on this.
Never mind, I lied @lgrcia!
I think I managed to skip the computation of the non-diagonal terms. The Green's basis is actually pretty similar to the polynomial basis, so I computed which indices $(l, m)$ correspond to the off-diagonal terms (in the polynomial basis) and skipped the computation of these terms in `solution_vector` (both in `p_integral` and `q_integral`).
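The skipping can happen at trace time with a plain Python conditional, so the dropped terms never enter the compiled graph (a sketch of the pattern only, not the actual `p_integral` code; `terms` and `mask` are hypothetical):

```python
import jax.numpy as jnp

def masked_sum(terms, mask):
    # `mask` is a static Python sequence of bools: the `if` is resolved while
    # tracing, so masked-out terms are absent from the jaxpr entirely, unlike
    # jnp.where, which would still evaluate both branches.
    total = jnp.zeros(())
    for term, keep in zip(terms, mask):
        if keep:
            total = total + term
    return total

terms = [jnp.asarray(1.0), jnp.asarray(10.0), jnp.asarray(100.0)]
print(masked_sum(terms, [True, False, True]))  # only the kept terms contribute
```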
Here are some benchmarks:
```python
import jax

jax.config.update("jax_enable_x64", True)

from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = 0.1
order = 20

function = jax.jit(solution_vector(2, order, diagonal=False))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(solution_vector(2, order, diagonal=True))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(solution_vector(2, order))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
```
```
all terms computed
6.44 μs ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Only diagonal terms
6.05 μs ± 24.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

core.limb_dark version
5.46 μs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
But if we consider a vmapped version:
```python
import jax

jax.config.update("jax_enable_x64", True)

from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = jax.numpy.linspace(0, 1 + r, 1000)
order = 20

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=False)), (0, None)))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=True)), (0, None)))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order)), (0, None)))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
```
```
all terms computed
295 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Only diagonal terms
190 μs ± 4.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

core.limb_dark version
74.5 μs ± 439 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
It would be nice to know where the difference comes from.
I'm starting to take a look at this, and we'll see how far I get. To begin with, I'm looking at the jaxpr for each calculation. I started with just the zeroth-order computation, which should be identical in both cases, and added `jax.make_jaxpr(function)(b, r)` after each benchmark. Here are the two jaxprs:
They seem to start off the same, but in the middle I'm seeing a lot more `device_put`s and `scatter`s in the starry version, so I'm going to see if I can track those down!
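One quick way to chase these down is to tally the primitives in each jaxpr and diff the tallies between the two versions (a small helper sketch; it only counts top-level equations, not nested sub-jaxprs):

```python
import collections
import jax
import jax.numpy as jnp

def count_primitives(fn, *args):
    # Trace fn to a jaxpr and count its top-level primitives; comparing these
    # counts between two implementations highlights extra device_put/scatter ops.
    closed = jax.make_jaxpr(fn)(*args)
    return collections.Counter(str(eqn.primitive) for eqn in closed.jaxpr.eqns)

f = jax.vmap(lambda b: jnp.sin(b) + jnp.cos(b))
print(count_primitives(f, jnp.linspace(0.0, 1.0, 10)))
```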
Oh - I think I know what it is! We should probably just merge the two implementations: use the closed-form solutions for `s0` and `s2`, and keep the numerical solutions for the others. I still think we might be able to optimize the starry implementation of the numerical solutions too, though. Where are those `scatter`s coming from?!
The matrices involved in the light curve computation from Agol 2019 (polynomial limb darkening) and Luger 2019 (more general) are inherently of different sizes, even if the Luger matrices are reduced to the minimal case of a limb-darkened surface (see https://github.com/exoplanet-dev/jaxoplanet/pull/204#issuecomment-2323046728). So I suspect we will never get exactly the same performance from the two. But maybe I'm wrong.
One thing I did to bridge the gap was to compute the change-of-basis matrix from Agol's Green's basis to Luger's polynomial basis, so that the smaller solution vector from Agol can be used directly in the limb-darkened case of a starry surface. Combined with excluding the non-diagonal values in Luger's matrices, performance is much better. But again, I'm not sure we can push it further.
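Schematically, once the change-of-basis matrix is precomputed, the runtime cost is a single small rectangular matrix-vector product (names here are illustrative, not the jaxoplanet API):

```python
import numpy as np

def lift_solution(s_agol, A_g2p):
    # A_g2p: hypothetical change-of-basis matrix taking Agol's compact
    # Green's-basis solution vector into Luger's (larger) polynomial basis.
    # It depends only on the degree, so it can be built once and reused.
    return A_g2p @ s_agol

A_g2p = np.eye(9)[:, :3]            # toy 9x3 embedding for a degree-2 surface
s_agol = np.array([1.0, 2.0, 3.0])  # toy compact solution vector
print(lift_solution(s_agol, A_g2p))
```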
With these changes, processing times for single values of `b` and `r` are equal! But I still don't get why the vmapped version of starry behaves differently (although its performance is now closer to the limb-dark one).
```python
import jax
import numpy as np

jax.config.update("jax_enable_x64", True)

from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve

r = 1.0
u = (1.0, 1.0)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
```
```
207 μs ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
132 μs ± 6.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
This time difference is the worst we will get, as the processing times become increasingly comparable for larger values of `order` and higher degrees of the polynomial limb-darkening law.
For a single `b`:
```python
import jax
import numpy as np

from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve

r = 0.1
u = (0.1, 0.2)
b = 0.1
order = 20

surface = Surface(u=u)
function = jax.jit(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
```
```
7.27 μs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.42 μs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
This time, the evaluation for a single `b` is just as fast, and the fact that we don't get that in the vmapped version is still a mystery...
I think these changes capture the idea of #204, since we are working with the minimal possible size of the starry matrices (without the zeros). Looking forward to hearing your ideas on this.
This PR is an attempt to understand what can be done to merge the limb-darkening light curve implementation with starry.
Motivation
The idea is that, if only limb darkening is present on a map defined with starry, JAX should figure out which parts of the different matrices are necessary to compute the starry light curve, and that these should be pretty similar in number to the ones involved in `core.limb_dark.light_curve`. Hence the performance of the starry light curve, i.e. `experimental.starry.light_curves.surface_light_curve`, should be similar to the limb-dark one.
Description of the current modifs
In the `Ylm` and `Pijk` interfaces, we defined a `diagonal` attribute which indicates whether the non-radial coefficients of the spherical harmonics basis, i.e. those not contributing to limb darkening, are all zero. The first commit in this PR fixes that for the `Pijk` basis, and allows passing a `diagonal` kwarg to the `Pijk.to_sparse` method, so that the computation takes into account the sparsity of `p_y`.
Results
For now this is not working: the starry light curve is ~4 times slower than the limb-dark one.
starry
limb-dark
The evaluation time is widely dominated by the call to `solution_vector` on both ends. But I wonder how to write the starry light curve function so that only the relevant parts of the solution vectors are computed when only limb darkening is present.
Something I noticed is that with a scalar `b` (no vmap) we get 9.42 μs vs. 7.49 μs.
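One possible shape for computing only the relevant parts, sketched under assumptions (`make_solution_vector`, `needed`, and `funcs` are all hypothetical names): pass the static set of needed indices into the solution-vector builder, so the jitted graph contains no work for the terms a limb-darkened map never uses.

```python
import jax
import jax.numpy as jnp

def make_solution_vector(needed, funcs):
    # `needed` is a static tuple of indices; only those entries of the
    # solution vector are built, so the traced graph contains no scatters
    # or integrals for the remaining terms.
    def sT(b, r):
        return jnp.stack([funcs[n](b, r) for n in needed])
    return jax.jit(sT)

# Toy per-term functions standing in for the real s_n integrals.
funcs = {0: lambda b, r: b + r, 2: lambda b, r: b * r, 5: lambda b, r: b - r}
sT = make_solution_vector((0, 2), funcs)
print(sT(0.1, 0.2))  # only entries 0 and 2 are computed
```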