minaskar opened this issue 4 years ago.
Yes, totally agree. Indeed, FFTLog also comes up in #30, although I think @sukhdeep1989 wants to implement a different approach for that.
I haven't dived into the details of mcfit before but I'm pretty sure @eelregit would be interested in a JAX version.
Another option I looked into was to just port this implementation to JAX: https://github.com/JoeMcEwen/FAST-PT/blob/master/fastpt/HT.py
Seemed pretty easy at first glance.
What I added is for projected correlation functions. Interfacing with mcfit would be great. FYI, Hankel transforms and window effects do not need to be differentiable (they have no dependence on parameters). So I'm not sure a JAX implementation is super important.
(Sent as an email reply to Francois Lanusse's comment of Jun 11, 2020.)
OK, so I think what you mean, @sukhdeep1989, is that we may not need to write everything in native JAX code: when integrals are involved, we can compute explicit JVPs and write custom autodiff rules instead of asking JAX to figure them out. For example:
$$\xi = \int dk\, k\, P(k)\, J(kr)$$
$$\xi' = \int dk\, k\, P'(k)\, J(kr)$$
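A minimal sketch of that idea, assuming a plain trapezoidal quadrature of the order-0 transform on a fixed k grid (not FFTLog, not mcfit, and not necessarily the rule proposed in #47). `kernel`, `hankel0` and `toy_power` are hypothetical names, and the toy spectrum P(k) = A exp(-k²/2) only stands in for a cosmology-dependent P(k). Because ξ is linear in P(k), the custom JVP just applies the same integral to P'(k):

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.special import j0

# Fixed grids: they do not depend on cosmology, so the kernel is built once.
k = np.logspace(-3, 1, 2000)
r = np.array([0.5, 1.0, 2.0])
dk = np.diff(k)
w = np.zeros_like(k)
w[:-1] += 0.5 * dk  # trapezoid weights on the (non-uniform) k grid
w[1:] += 0.5 * dk
kernel = jnp.asarray(w * k * j0(np.outer(r, k)))  # shape (len(r), len(k))


@jax.custom_jvp
def hankel0(pk):
    """xi(r) ~ int dk k P(k) J_0(k r), evaluated as a quadrature on the fixed grid."""
    return kernel @ pk


@hankel0.defjvp
def hankel0_jvp(primals, tangents):
    (pk,), (pk_dot,) = primals, tangents
    # The integral is linear in P(k): the tangent xi' is the same integral applied to P'(k).
    return hankel0(pk), kernel @ pk_dot


def toy_power(amplitude):
    # Hypothetical stand-in for a cosmology-dependent power spectrum.
    return amplitude * jnp.exp(-jnp.asarray(k) ** 2 / 2)
```

With this toy P(k) the quadrature should reproduce ξ(r) = exp(-r²/2), and both `jax.jacfwd` and `jax.grad` with respect to the amplitude go through the custom rule, since its tangent output is linear in P'(k).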
See this other issue I have opened, #47, to do this more generally for all integrals, hopefully making JAX compilation faster than it currently is.
Happy to help if needed!
$$\xi = \int dk\, k\, P(k)\, J(kr)$$
$$\xi' = \int dk\, k\, P'(k)\, J(kr)$$
I wonder if the second prime should be with respect to r and thus on J?
Email reply to Yin Li's comment of Jun 11, 2020:
The prime is with respect to cosmology, from what I understand. $kr = \ell\theta$ is independent of cosmology.
I see. I agree that the integral and d/d(cosmology) are orthogonal.
Would it suffice to replace numpy with jax.numpy in mcfit?
It's very possible :-D
I am thinking about what a good interface would be for switching between the numpy and jax backends. I also need to read some JAX docs.
Any recommendation? :)
That's a good question. I don't think there is an easy way to do it, or at least not a generic one. One example I know of, switching between TF and NumPy backends, is here: https://github.com/google/edward2#using-the-numpy-backend And one switching between TF, NumPy, and JAX: https://github.com/tensorflow/probability/tree/master/tensorflow_probability/python/experimental/substrates But the mechanism seems pretty complicated: they have a script that rewrites the code on the fly for each backend.
Indeed, that looks quite complicated. I will look at them more closely. Otherwise, it is straightforward to copy the code and replace numpy with jax.numpy, if you don't mind a PR like that ;)
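A lighter-weight alternative to rewriting source files, sketched under the assumption that the code only uses the common subset of the NumPy and jax.numpy APIs: pick the array module once, when the object is constructed. `LogGridIntegral` is a hypothetical toy (a trapezoid rule on a log-spaced grid), not mcfit's interface:

```python
import numpy as np
import jax
import jax.numpy as jnp


class LogGridIntegral:
    """Hypothetical toy: trapezoid rule for int f(x) dx = int f(x) x d(ln x)
    on a log-spaced grid, written once against an array module chosen at construction."""

    def __init__(self, x, backend="numpy"):
        if backend not in ("numpy", "jax"):
            raise ValueError(f"unknown backend: {backend!r}")
        self.xp = jnp if backend == "jax" else np
        self.lnx = self.xp.log(self.xp.asarray(x))

    def __call__(self, f):
        xp = self.xp
        g = xp.asarray(f) * xp.exp(self.lnx)  # integrand with respect to ln x
        dlnx = xp.diff(self.lnx)
        return xp.sum(0.5 * (g[1:] + g[:-1]) * dlnx)
```

With `backend="jax"` the same object can be passed to `jax.grad` or `jax.jit`; with `backend="numpy"` it has no JAX dependency at call time.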
Hi everyone!
I had to implement the Hankel transform in JAX for a project I was working on; perhaps it is useful here as well. You can find it in this repo: https://github.com/florpi/JaxHankel
And actually, maybe you can help me understand why it works, since I'm converting jax arrays to numpy arrays and using a scipy function at one point (see here: https://github.com/florpi/JaxHankel/blob/main/jax_fht/fht.py#L61). But the final derivatives seem fine.
Hey Carolina :-D
That looks super useful indeed! To answer your question right away: it will work even if you use scipy functions (because it implicitly converts back and forth between numpy and jax.numpy arrays), but only until you try to use jax.jit or jax.grad, at which point it will fail.
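A small sketch of that failure mode, using scipy.special.gammaln as a stand-in for the SciPy call; `f_scipy` and `fails_under` are hypothetical helpers. Calling SciPy on a concrete JAX array works, but inside jax.grad or jax.jit the argument is a tracer, and converting it to a NumPy array raises an error (jax.errors.TracerArrayConversionError in recent JAX versions):

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.special import gammaln as scipy_gammaln


def f_scipy(x):
    # np.asarray needs concrete values: fine for an ordinary JAX array,
    # but not for the tracers JAX passes in under jax.grad / jax.jit.
    return jnp.sum(jnp.asarray(scipy_gammaln(np.asarray(x))))


def fails_under(transform, fn, x):
    """Return True if applying `transform` (e.g. jax.grad) to `fn` raises on input x."""
    try:
        transform(fn)(x)
    except Exception:  # jax.errors.TracerArrayConversionError in recent JAX
        return True
    return False
```

Here `f_scipy(jnp.array([2.0, 3.0]))` evaluates eagerly, while `fails_under(jax.grad, f_scipy, ...)` and `fails_under(jax.jit, f_scipy, ...)` both report a failure.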
But it looks like at least gammaln is implemented in JAX already: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.gammaln.html#jax.scipy.special.gammaln
Wouldn't that be enough for your use case? I think it does the same as loggamma(x) if x is positive.
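For real, positive arguments the JAX version can be swapped in directly; a minimal sketch (`sum_lgamma` and `grad_lgamma` are hypothetical names). The derivative of log Gamma is the digamma function, which is what the gradient should return:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def sum_lgamma(x):
    # log Gamma of real, positive arguments, written with JAX primitives only
    return jnp.sum(gammaln(x))


# Both transformations that break the SciPy version work here;
# d/dx log Gamma(x) = digamma(x).
grad_lgamma = jax.jit(jax.grad(sum_lgamma))
```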
Ah no, sorry, I see now that the argument can be complex :-| But then it should "just" be a matter of adding an implementation of loggamma.
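If a native version were wanted, one standard approach for complex arguments with positive real part is to shift the argument with the recurrence log Gamma(z) = log Gamma(z + N) - sum_{j=0}^{N-1} log(z + j) and then apply the Stirling series at z + N. The sketch below assumes Re(z) > 0, turns on 64-bit mode (a global JAX setting), and uses the hypothetical name `loggamma_complex`; it is not taken from JaxHankel or mcfit:

```python
import jax

jax.config.update("jax_enable_x64", True)  # complex128 for accuracy (global setting)
import jax.numpy as jnp


def loggamma_complex(z, shift=8):
    """Principal-branch log Gamma(z) for Re(z) > 0 (sketch)."""
    z = jnp.asarray(z, dtype=jnp.complex128)
    # Recurrence: log Gamma(z) = log Gamma(z + shift) - sum_j log(z + j), j = 0..shift-1
    j = jnp.arange(shift)
    log_prod = jnp.sum(jnp.log(z[..., None] + j), axis=-1)
    # Stirling series for log Gamma(w) at w = z + shift, where |w| >= shift
    w = z + shift
    inv = 1.0 / w
    inv2 = inv * inv
    series = inv * (1.0 / 12 + inv2 * (-1.0 / 360 + inv2 * (1.0 / 1260
             + inv2 * (-1.0 / 1680 + inv2 / 1188))))
    stirling = (w - 0.5) * jnp.log(w) - w + 0.5 * jnp.log(2.0 * jnp.pi) + series
    return stirling - log_prod
```

Since every logarithm is the principal one, the result is analytic away from the negative real axis and should follow the same branch as scipy.special.loggamma; it can also be wrapped in `jax.jit`.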
Hey Carolina :)
I agree with Francois. Maybe there's a way to cache the scipy loggamma values in jnp.ndarrays, as long as one does not need derivatives with respect to the x values (scales).
Hi both! Thanks for your comments :) I also thought that calling grad or jacobian would fail, but you can see that it works in this test: https://github.com/florpi/JaxHankel/blob/main/test_jax_fht/test_analytical_cosmology.py#L39 Maybe I'm misunderstanding how jacobian works?
Regarding loggamma, the issue I had was exactly the complex number extension, hehe. I haven't looked into it much, so it might not be hard. Any ideas are welcome.
Ha yes, sorry, I had missed that you were computing the jacobian.
Yes, so this works because the loggamma function is only used to compute coefficients for the Hankel transform. These coefficients are fixed, and you don't take derivatives with respect to them, so there is no problem when jit compiling or taking gradients.
In practice, when you jit compile that function, as long as the arguments to the scipy call are fixed, the scipy code is called while JAX traces the function, and the results are stored as a "constant" that is then used in the rest of the JAX code.
So most likely there is no need to reimplement loggamma :-D
Thank you, that makes a lot of sense! Also, not having to reimplement loggamma makes me very happy :D
In order for jax_cosmo to be used in observational cosmology analyses (e.g. BAO, RSD, fNL), we need a JAX implementation of the FFTLog algorithm to facilitate survey window function convolutions with the power spectrum.
This would also help in computing models of the correlation function, as mentioned in another post.
There's already a package that's used very often in cosmology: https://github.com/eelregit/mcfit
It should be possible to implement it using JAX.
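For reference, a compact sketch of the zero-bias FFTLog idea (Hamilton-style, without low-ringing tuning of k0 r0, windowing, or padding), assuming an even number of log-uniform samples and mu > -1. It computes ã(k) = ∫ a(r) J_mu(kr) k dr. The kernel coefficients depend only on the grid, so they are evaluated once with scipy.special.loggamma and enter JAX as constants, in the spirit of the loggamma discussion in this thread. `make_fftlog_hankel` is a hypothetical name; this is not mcfit's or JaxHankel's API:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.special import loggamma


def make_fftlog_hankel(r, mu=0.0):
    """Return (k, transform) with transform(a) ~ int_0^inf a(r) J_mu(k r) k dr,
    for a(r) sampled on the log-uniform grid r (zero bias; even len(r) assumed)."""
    r = np.asarray(r, dtype=np.float64)
    n = r.size
    dlnr = np.log(r[1] / r[0])
    k = np.exp(np.arange(n) * dlnr) / r[-1]        # output grid, k[0] * r[-1] = 1
    eta = 2.0 * np.pi * np.fft.fftfreq(n, d=dlnr)  # eta_m = 2 pi m / (n dlnr)

    # u_m = (k_0 r_0)^(-i eta) 2^(i eta) Gamma((mu+1+i eta)/2) / Gamma((mu+1-i eta)/2),
    # evaluated once with SciPy: it depends only on the grid and mu, not on a(r).
    x = 1j * eta
    log_u = (x * np.log(2.0)
             + loggamma((mu + 1.0 + x) / 2.0) - loggamma((mu + 1.0 - x) / 2.0)
             - x * np.log(k[0] * r[0]))
    u = np.exp(log_u)
    if n % 2 == 0:
        u[n // 2] = u[n // 2].real  # Nyquist term: keep the output real
    u = jnp.asarray(u)

    @jax.jit
    def transform(a):
        c = jnp.fft.fft(a) / n          # a(r_j) = sum_m c_m (r_j / r_0)^(i eta_m)
        return jnp.fft.fft(c * u).real  # a~(k_j) = sum_m c_m u_m exp(-2 pi i m j / n)

    return k, transform
```

For a(r) = r exp(-r²/2) and mu = 0 the exact result is ã(k) = k exp(-k²/2), which gives a quick check on the central part of the k grid. In general, values near the grid edges are less reliable because the FFT treats a(r) as periodic in ln r. The transform is linear in a(r), so `jax.grad` passes straight through the two FFTs.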