DifferentiableUniverseInitiative / jax_cosmo

A differentiable cosmology library in JAX
MIT License

Implement explicit JVP for integrals #47

Open EiffL opened 4 years ago

EiffL commented 4 years ago

Right now we let JAX figure out the gradients of the integration automatically. That works OK, but I think it's causing a large memory overhead, and when autodiffing lax.scan calls the graph takes very long to compile. So instead, we should probably implement a custom JVP for integrals, which is easy: it's just the integral of the inner JVP.
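The idea in a nutshell, as a toy sketch (all names here are made up, nothing from the library): for a fixed quadrature rule, differentiation and summation commute, so the JVP of the quadrature is the same quadrature applied to the JVP of the integrand.

import jax
import jax.numpy as np

# Toy check of the identity on a trapezoid rule:
def quad(f, x, w, theta):
    return np.sum(w * jax.vmap(lambda xi: f(xi, theta))(x))

f = lambda x, theta: np.sin(theta * x)
x = np.linspace(0.0, 1.0, 129)
w = np.full_like(x, 1.0 / 128).at[0].mul(0.5).at[-1].mul(0.5)  # trapezoid weights

# JVP of the quadrature...
_, t1 = jax.jvp(lambda th: quad(f, x, w, th), (1.0,), (1.0,))
# ...equals the quadrature applied to the JVP of the integrand:
t2 = quad(lambda xi, th: jax.jvp(lambda t: f(xi, t), (th,), (1.0,))[1], x, w, 1.0)
# t1 and t2 agree up to floating point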

I'm opening this issue to document my experiments with this and to gather any thoughts or ideas anyone may have.

To get things started, here is my first attempt at a custom Simpson integration using jax.lax.scan and a custom JVP:

from functools import partial

import jax
import jax.numpy as np
from jax import custom_jvp


def my_simps(func, a, b, *args, N=128):
    if N % 2 == 1:
        raise ValueError("N must be an even integer.")
    dx = (b - a) / N
    x = np.linspace(a, b, N + 1)
    return _custom_simps(func, x, dx, *args)


@partial(custom_jvp, nondiff_argnums=(0, 1, 2))
def _custom_simps(func, x, dx, *args):
    f = lambda x: func(x, *args)

    # Rematerialize the integrand inside the scan instead of storing
    # every intermediate, to save memory when differentiating.
    @jax.remat
    def loop_fn(carry, x):
        y = f(x)
        s = 4 * y[0] + 2 * y[1]
        return carry + s, 0

    # carry starts at f(x_0); each scan step adds 4*f(x_odd) + 2*f(x_even)
    r, _ = jax.lax.scan(loop_fn, f(np.atleast_1d(x[0]))[0], x[1:].reshape((-1, 2)))
    # the last node has Simpson weight 1, not 2, so subtract one f(x_N)
    S = dx / 3 * (r - f(np.atleast_1d(x[-1]))[0])
    return S

@_custom_simps.defjvp
def _custom_simps_jvp(func, x, dx, primals, tangents):
    # f now returns the (primal, tangent) pair of the integrand at x
    f = lambda x: jax.jvp(lambda *args: func(x, *args), primals, tangents)

    def loop_fn(carry, x):
        s1 = f(x[0])
        s2 = f(x[1])
        # accumulate 4*f(x_odd) + 2*f(x_even) for primal and tangent alike
        return jax.tree_multimap(lambda a, b, c: a + 4 * b + 2 * c, carry, s1, s2), 0

    r, _ = jax.lax.scan(loop_fn, f(x[0]), x[1:].reshape((-1, 2)))
    # same endpoint correction as the primal rule, applied to both outputs
    S = jax.tree_multimap(lambda a, b: dx / 3 * (a - b), r, f(x[-1]))
    return S

It seems to work in simple examples, but it is still hitting a strange issue in the lax.scan transpose rule used in reverse-mode AD.
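For what it's worth, a minimal forward-mode check on a made-up integrand (so only illustrative) does go through the custom JVP, since jax.jacfwd only needs JVPs; it's jax.grad, i.e. reverse mode, that triggers the transpose issue:

# integrate p * x**2 over [0, 1] and differentiate w.r.t. p
val = my_simps(lambda x, p: p * x**2, 0.0, 1.0, 2.0)                             # ~ 2/3
dval = jax.jacfwd(lambda p: my_simps(lambda x, p: p * x**2, 0.0, 1.0, p))(2.0)   # ~ 1/3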

eelregit commented 4 years ago

For Simpson's rule, is it possible to autodiff dx/3 * sum(y[0:-2:2] + 4*y[1::2] + y[2::2])?
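Something like this sketch, I mean (names made up, same np = jax.numpy convention as above):

import jax.numpy as np

def simps_vectorized(func, a, b, *args, N=128):
    # evaluates all N+1 nodes at once and lets JAX trace through the sum
    x = np.linspace(a, b, N + 1)
    y = func(x, *args)
    dx = (b - a) / N
    return dx / 3 * np.sum(y[0:-2:2] + 4 * y[1::2] + y[2::2])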

EiffL commented 4 years ago

Oh yeah, for sure, that's what I'm currently doing.

eelregit commented 4 years ago

I see! Would the integral method of the spline be any faster?

EiffL commented 4 years ago

Not really, it's more or less the same complexity. I think the cost would end up about the same.

EiffL commented 4 years ago

So the main thing to consider is batching. For the things I want to do, I would for instance sample 128 different cosmologies at once, e.g. to run 128 chains in parallel. Let's say I want to compute some lensing Cl with the Simpson formula you had above; at some point in the calculation, I will have a tensor of shape:

[batch, nCls, n_ell, n_a1, n_a2]

where nCls is the number of cross-spectra, n_ell is the number of ells, n_a1 is the number of points in a for the Limber integral evaluated by simps, and n_a2 is the number of points in a for the integral of the lensing kernel.

So it quickly makes pretty big arrays, especially due to the memory overhead of the grads.

That's why I wanted an integration method that uses lax.scan, i.e. an actual for loop, to compute some integrals sequentially instead of in parallel and save memory.
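Schematically the trade-off is something like this (hypothetical names and shapes, just to illustrate, with a plain Riemann sum standing in for the real quadrature):

import jax
import jax.numpy as np

# integrand: a -> array of shape [batch, nCls, n_ell]

def integrate_vectorized(integrand, a_grid):
    # materializes the whole [n_a1, batch, nCls, n_ell] tensor at once
    vals = jax.vmap(integrand)(a_grid)
    da = a_grid[1] - a_grid[0]
    return da * np.sum(vals, axis=0)

def integrate_sequential(integrand, a_grid, out_shape):
    # only one [batch, nCls, n_ell] slice of the integrand is live per step
    da = a_grid[1] - a_grid[0]
    def step(carry, a):
        return carry + integrand(a) * da, None
    total, _ = jax.lax.scan(step, np.zeros(out_shape), a_grid)
    return total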

eelregit commented 4 years ago

Thanks for the explanation! I was hoping the spline might need fewer knots, even if both are O(n) in complexity.

In this large-memory limit, do you expect a speedup from batching as in #32? Or maybe batching is not just for efficiency? Actually, in #32 I don't understand why it's a gain if the times are the same. Sorry for so many questions :)

EiffL commented 4 years ago

Ah yes, they do; check out our fresh spline PR #54, far fewer points are necessary. So this may do the trick. I have implemented spline integration, but haven't tested it in this batch setting yet.

So, because things happen in parallel, running a batch of size 1 takes the same time as running a batch of size 128, if you can afford the memory. So if running a batch takes 1s, you get 128 different cosmologies in 1s instead of just one.
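In JAX that's just a vmap over the parameters; with a toy stand-in for the cosmology calculation:

import jax
import jax.numpy as np

def model(omega_m):
    # toy stand-in for an expensive cosmology calculation
    return np.sum(np.sin(omega_m * np.linspace(0.0, 10.0, 1000)))

params = np.linspace(0.1, 0.5, 128)   # 128 "cosmologies"
results = jax.vmap(model)(params)     # one batched call, shape (128,)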

I care about batching because I want to use SVI (stochastic variational inference) ^^' or train neural networks with gradients that go through cosmology calculations.

Absolutely no problem :-) keep asking! If I can get you interested that would be awesome :-)