DifferentiableUniverseInitiative / jax_cosmo

A differentiable cosmology library in JAX

Implement Hankel Transform / FFTLog #52

Open minaskar opened 4 years ago

minaskar commented 4 years ago

In order for jax_cosmo to be used in observational cosmology analyses (e.g. BAO, RSD, fNL), we need a JAX implementation of the FFTLog algorithm to facilitate convolving the survey window function with the power spectrum.

This would also help with getting models of the correlation function, as mentioned in another issue.

There's already a package that's used very often in cosmology: https://github.com/eelregit/mcfit

It should be possible to implement it using JAX.

EiffL commented 4 years ago

Yes totally agree. Indeed FFTLog also comes up in #30 although I think @sukhdeep1989 wants to implement a different approach for that.

I haven't dived into the details of mcfit before but I'm pretty sure @eelregit would be interested in a JAX version.

Another option I looked into was to simply port this implementation to JAX: https://github.com/JoeMcEwen/FAST-PT/blob/master/fastpt/HT.py

Seemed pretty easy at first glance.

sukhdeep2 commented 4 years ago

What I added is for projected correlation functions. Interfacing with mcfit would be great. FYI, Hankel transforms and window effects do not need to be differentiable (there is no dependence on parameters), so I'm not sure an implementation in JAX is super important.


EiffL commented 4 years ago

Ok, so I think what you mean @sukhdeep1989 is that we may not need to write everything in native JAX code: when integrals are involved, we can compute explicit JVPs and write custom autodiff rules instead of asking JAX to figure it out. So, for example:

xi = ∫ dk k P(k) J(kr)
xi' = ∫ dk k P'(k) J(kr)

See issue #47, which I opened to do this more generally for all integrals, hopefully also making JAX compilation faster than it currently is.
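
Something like this sketch, for instance (illustrative only, not actual jax_cosmo code; the trapezoid rule and the spherical Bessel j0 stand in for the real quadrature and kernel):

```python
import jax
import jax.numpy as jnp

# Fixed grid and separation: kr does not depend on cosmology, so only
# P(k) carries a derivative. All concrete choices here are illustrative.
k = jnp.logspace(-4, 2, 2048)
r = 10.0

def J(x):
    # spherical Bessel j0(x) = sin(x)/x, via sinc (jnp.sinc(t) = sin(pi t)/(pi t))
    return jnp.sinc(x / jnp.pi)

def trapz(y, x):
    # simple trapezoid quadrature
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))

@jax.custom_jvp
def xi(pk):
    # xi = ∫ dk k P(k) J(kr)
    return trapz(k * pk * J(k * r), k)

@xi.defjvp
def xi_jvp(primals, tangents):
    (pk,), (dpk,) = primals, tangents
    primal = trapz(k * pk * J(k * r), k)
    # xi' = ∫ dk k P'(k) J(kr): the tangent passes straight through the integral
    tangent = trapz(k * dpk * J(k * r), k)
    return primal, tangent
```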

eelregit commented 4 years ago

Happy to help if needed!

xi = ∫ dk k P(k) J(kr)
xi' = ∫ dk k P'(k) J(kr)

I wonder if the second prime should be with respect to r and thus on J?

sukhdeep2 commented 4 years ago


The prime is with respect to cosmology, from what I understand. kr (= ℓθ) is independent of cosmology.


eelregit commented 4 years ago

I see. I agree that the integral and d/d(cosmology) are orthogonal.

Will it suffice to replace numpy by jax.numpy in mcfit?

EiffL commented 4 years ago

It's very possible :-D

eelregit commented 4 years ago

I am wondering what a good interface would be for switching between the numpy and jax backends. I also need to read some jax docs.

Any recommendation? :)

EiffL commented 4 years ago

That's a good question. I don't think there is an easy way to do it, or at least not a generic one. One example I know of backend switching between TF and NumPy is here: https://github.com/google/edward2#using-the-numpy-backend

Another, switching between TF, NumPy, and JAX, is here: https://github.com/tensorflow/probability/tree/master/tensorflow_probability/python/experimental/substrates

But the mechanism seems pretty complicated:

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/experimental/substrates/meta/rewrite.py

They have a script that rewrites the code on the fly for the various backends.

eelregit commented 4 years ago

Indeed, that looks quite complicated. I will look at them more closely. Otherwise, it is straightforward to copy the code and replace numpy with jax.numpy, if you don't mind a PR like that ;)
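
For the interface, maybe something as simple as this would do (a sketch of one possible pattern, not mcfit's actual API; get_backend and the xp name are invented):

```python
# Select the array module once and route all array calls through `xp`.
def get_backend(name="numpy"):
    if name == "jax":
        import jax.numpy as xp
    else:
        import numpy as xp
    return xp

xp = get_backend("jax")
a = xp.logspace(-2, 2, 128)  # identical call under either backend
```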

florpi commented 2 years ago

Hi everyone!

I had to implement the Hankel transform in JAX for a project I was working on; perhaps it is useful here as well. You can find it in this repo: https://github.com/florpi/JaxHankel

And actually, maybe you can help me understand why it works, since I'm converting jax arrays to numpy arrays and using a scipy function at one point (see here: https://github.com/florpi/JaxHankel/blob/main/jax_fht/fht.py#L61). But the final derivatives seem fine.

EiffL commented 2 years ago

Hey Carolina :-D

That looks super useful indeed! To answer your question right away: it will work even though you use scipy functions (JAX implicitly converts back and forth between numpy and jax.numpy arrays), right up until you try to use jax.jit or jax.grad on that code path, and then it will fail.
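
For example (a toy sketch, not your actual code): a scipy call on jax arrays works eagerly, because the arrays convert silently, but raises once jit traces the function:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import loggamma

def f(x):
    # np.asarray converts a concrete jax array to numpy without complaint...
    return jnp.sum(jnp.asarray(np.real(loggamma(np.asarray(x)))))

f(jnp.array([1.0, 2.0]))             # eager mode: works fine
# jax.jit(f)(jnp.array([1.0, 2.0]))  # fails: a tracer cannot be converted to numpy
```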

But, it looks like at least gammaln is implemented in jax already: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.gammaln.html#jax.scipy.special.gammaln

Wouldn't that be enough for your use case? I think it does the same as loggamma(x) when x is positive.
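
A quick sanity check of that claim, for positive real x:

```python
from jax.scipy.special import gammaln
from scipy.special import loggamma

x = 2.7
print(gammaln(x))   # ~0.4348, i.e. log Gamma(2.7)
print(loggamma(x))  # same value; loggamma additionally accepts complex arguments
```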

EiffL commented 2 years ago

Ah no, sorry, I see now that the argument can be complex :-| But then it should "just" be a matter of adding an implementation of loggamma.

eelregit commented 2 years ago

Hey Carolina :)

I agree with Francois. Maybe there's a way to cache the scipy loggamma values in jnp.ndarrays, as long as one does not need derivatives with respect to the x values (the scales).

florpi commented 2 years ago

Hi both! Thanks for your comments :) I also thought that calling grad or jacobian would fail, but you can see that it works in this test: https://github.com/florpi/JaxHankel/blob/main/test_jax_fht/test_analytical_cosmology.py#L39. Maybe I'm misunderstanding how jacobian works?

Regarding loggamma, the issue I had was exactly the complex-number extension hehe. I haven't looked into it much, so it might not be hard. Any ideas are welcome!

EiffL commented 2 years ago

Ha yes, sorry, I had missed that you were taking the jacobian.

Yes, so this works because the loggamma function is only used to compute coefficients for the Hankel transform. These coefficients are fixed, and you don't take derivatives with respect to them, so there is no problem when jit-compiling or taking gradients.

In practice, when you jit-compile that function, as long as those arguments are fixed, the scipy code is called once at trace time and the result is stored as a "constant" that is then used in the rest of the jax code.

So most likely there is no need to reimplement loggamma :-D
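
In code, the mechanism looks something like this (a toy sketch with an invented coefficient, not your actual transform):

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import loggamma

# The grid is concrete (never traced), so scipy runs at trace time and
# its output is baked into the compiled function as a constant.
x = np.linspace(1.0, 4.0, 64)
coeff = jnp.asarray(np.real(loggamma(x)))

@jax.jit
def transform(pk):
    return jnp.sum(coeff * pk)  # only pk is traced

print(jax.grad(transform)(jnp.ones(64)))  # gradient w.r.t. pk works fine
```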

florpi commented 2 years ago

Thank you, that makes a lot of sense! Also, not having to reimplement loggamma makes me very happy :D

eelregit commented 2 years ago

I think the kernels (e.g. _u here) are to be evaluated at points that are determined by the input k or r scale (via e.g. Delta in that link). So the kernel values are not jit constants unless we make the input scales static?
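
One way to keep them static would be a factory over a concrete grid, so the scipy evaluation happens once at construction time (an illustrative sketch; make_transform and its internals are invented, not mcfit's API):

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import loggamma

def make_transform(k):
    # k is a concrete numpy array: Delta and the kernel value are computed
    # eagerly with scipy, then frozen into the returned jitted function.
    delta = np.log(k[-1] / k[0]) / (len(k) - 1)            # log spacing, as in FFTLog
    u = jnp.asarray(np.real(loggamma(1.0 + 1j / delta)))   # toy kernel coefficient
    @jax.jit
    def transform(pk):
        return u * jnp.sum(pk)  # placeholder for the actual FFTLog steps
    return transform

f = make_transform(np.logspace(-3, 1, 128))
```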