LSSTDESC / bayesian-pipelines-cosmology

Bayesian Cosmological Inference from Tomographic Maps
MIT License

Matter power spectrum as function of cosmology (most likely emulator, or E.H.) #2

Open EiffL opened 2 years ago

EiffL commented 2 years ago

This issue is to discuss the particular choice for the matter power spectrum that will be an input to all forward pipelines.

Here is a question to get us started: @Supranta @nataliaporqueres, what are your options in terms of differentiable matter power spectra?

With @jecampagne and @dlanzieri we are still just using Eisenstein & Hu in jax-cosmo and flowpm respectively, but we can adapt to whatever other solution you are currently using.

nataliaporqueres commented 2 years ago

In BORG, we have Eisenstein-Hu and CLASS.

EiffL commented 2 years ago

Ok great, thanks @nataliaporqueres :-) But I guess you can't take derivatives of CLASS. Do you have a CLASS emulator, or do you use some trick, such as finite differences, to get the derivatives from CLASS?

nataliaporqueres commented 2 years ago

We only need the derivative with respect to the initial conditions delta^{IC}, because HMC is only used to sample the density field. We use a slice sampler for the cosmological parameters, which doesn't require derivatives.
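The split described above can be illustrated with a minimal sketch of a univariate slice sampler (the stepping-out and shrinkage procedure of Neal, 2003). It only evaluates the log-posterior and needs no gradients. This is not BORG's implementation: `log_post` is a hypothetical stand-in for the conditional posterior of the cosmological parameters at a fixed density field, and the parameters are updated one at a time.

```python
import numpy as np


def slice_sample_1d(log_post, x0, width=1.0, max_steps=50, rng=None):
    """One slice-sampling update of a scalar x0 (stepping-out + shrinkage).

    Only log_post is evaluated; no derivatives are needed.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Slice height in log space: log f(x0) + log(u), u ~ U(0, 1).
    log_y = log_post(x0) + np.log(rng.uniform())
    # Initial interval of size `width`, randomly positioned around x0.
    left = x0 - width * rng.uniform()
    right = left + width
    # Step out until both ends are outside the slice (bounded number of steps).
    j = int(np.floor(max_steps * rng.uniform()))
    k = max_steps - 1 - j
    while j > 0 and log_post(left) > log_y:
        left -= width
        j -= 1
    while k > 0 and log_post(right) > log_y:
        right += width
        k -= 1
    # Shrink towards x0 until a proposal falls inside the slice.
    while True:
        x1 = rng.uniform(left, right)
        if log_post(x1) > log_y:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


def slice_sweep(log_post, theta, rng, width=1.0):
    """Update each component of theta in turn with a 1D slice step (Gibbs-style sweep)."""
    theta = np.array(theta, dtype=float)
    for i in range(theta.size):
        def log_cond(xi, i=i):
            t = theta.copy()
            t[i] = xi
            return log_post(t)
        theta[i] = slice_sample_1d(log_cond, theta[i], width=width, rng=rng)
    return theta


# Usage on a toy target (2D standard normal), standing in for p(theta | density field).
rng = np.random.default_rng(0)
log_post = lambda t: -0.5 * np.sum(t ** 2)
theta = np.zeros(2)
samples = []
for _ in range(5000):
    theta = slice_sweep(log_post, theta, rng)
    samples.append(theta)
samples = np.array(samples)
print(samples.mean(axis=0), samples.std(axis=0))
```

In a block scheme of the kind described, HMC steps on the density field would alternate with `slice_sweep` updates of the cosmological parameters, so only the field update needs gradients.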

EiffL commented 2 years ago

Ah! Interesting! I hadn't realized that, sorry. Very interesting :-)

Supranta commented 2 years ago

I have a PCE emulator as well as a GP emulator for the power spectrum. I am currently adapting my GP emulator to work with tinygp, which is based on JAX, so I can add it here.
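For readers unfamiliar with GP emulators, here is a minimal sketch of the underlying idea in plain JAX. It does not use tinygp's API and is not Supranta's code: a zero-mean GP (after subtracting a constant mean) with a squared-exponential kernel is fitted to a scalar summary, such as log P at one k bin, as a function of cosmological parameters, and predictions use the posterior mean. The kernel hyperparameters are fixed by hand here; a real emulator would optimise them.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Double precision helps the Cholesky factorisation of smooth-kernel matrices.
jax.config.update("jax_enable_x64", True)


def rbf_kernel(xa, xb, amp, lengthscales):
    """Squared-exponential kernel, one length-scale per input dimension."""
    d = (xa[:, None, :] - xb[None, :, :]) / lengthscales
    return amp ** 2 * jnp.exp(-0.5 * jnp.sum(d ** 2, axis=-1))


def gp_fit(x_train, y_train, amp, lengthscales, jitter=1e-6):
    """Return the constant mean and alpha = (K + jitter I)^{-1} (y - mean), computed once."""
    mean = jnp.mean(y_train)
    K = rbf_kernel(x_train, x_train, amp, lengthscales)
    K = K + jitter * jnp.eye(x_train.shape[0])
    alpha = cho_solve(cho_factor(K, lower=True), y_train - mean)
    return mean, alpha


@jax.jit
def gp_predict(x_new, x_train, mean, alpha, amp, lengthscales):
    """Posterior mean m + K(x_new, X) alpha; differentiable with respect to x_new."""
    return mean + rbf_kernel(x_new, x_train, amp, lengthscales) @ alpha


# Usage: 50 toy "cosmologies" in 2D and a smooth toy target (e.g. log P at one k bin).
x_train = jax.random.uniform(jax.random.PRNGKey(0), (50, 2))
f = lambda x: jnp.sin(3.0 * x[:, 0]) + x[:, 1] ** 2
y_train = f(x_train)
amp, ls = 1.0, jnp.array([0.3, 0.3])
mean, alpha = gp_fit(x_train, y_train, amp, ls)

x_test = jnp.array([[0.5, 0.5]])
pred = gp_predict(x_test, x_train, mean, alpha, amp, ls)
grad_pred = jax.grad(lambda x: gp_predict(x[None, :], x_train, mean, alpha, amp, ls)[0])(
    jnp.array([0.5, 0.5]))
```

Because prediction is a kernel evaluation followed by a dot product, it is cheap and differentiable with respect to the cosmological inputs, which is what makes GP emulators attractive for gradient-based samplers.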

jecampagne commented 2 years ago

I have set up this Jax-CLASS emulator, based on the original work by the authors of https://arxiv.org/pdf/2105.02256.pdf

EiffL commented 2 years ago

This is great @jecampagne :-D Only thing is your repo is private, people will probably not be able to see the content, would you be ok with making it public?

jecampagne commented 2 years ago

Proposal for a new set of cosmologies: I propose the following ranges for the primary parameters (1000 Latin hypercube points):
‘Omega_cdm’: [0.1, 0.5]
‘Omega_b’: [0.04, 0.06]
‘sigma8’: [0.7, 0.9]
‘n_s’: [0.87, 1.06]
‘h’: [0.55, 0.8]
Below are the distributions of these variables (first row), together with derived parameters (second row). The red vertical lines mark the fiducial values.
[Figure: parameter distributions]
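A design like the one proposed can be generated with a few lines of NumPy. This is a minimal Latin hypercube sketch over the quoted ranges; the seed and the Omega_m example of a derived parameter are illustrative assumptions, and `scipy.stats.qmc.LatinHypercube` offers an equivalent construction. Each range is cut into 1000 equal strata, every stratum is hit exactly once per parameter, and strata are paired at random across parameters.

```python
import numpy as np

# Primary-parameter ranges proposed above.
PARAM_RANGES = {
    "Omega_cdm": (0.1, 0.5),
    "Omega_b": (0.04, 0.06),
    "sigma8": (0.7, 0.9),
    "n_s": (0.87, 1.06),
    "h": (0.55, 0.8),
}


def latin_hypercube(n, ranges, rng):
    """n points: each range is split into n equal strata, each stratum used once per parameter."""
    names = list(ranges)
    lo = np.array([ranges[p][0] for p in names])
    hi = np.array([ranges[p][1] for p in names])
    # argsort of uniform noise gives an independent random permutation of 0..n-1 per column;
    # adding U(0, 1) jitter places each point uniformly inside its stratum.
    u = (np.argsort(rng.random((n, len(names))), axis=0) + rng.random((n, len(names)))) / n
    return names, lo + u * (hi - lo)


names, samples = latin_hypercube(1000, PARAM_RANGES, np.random.default_rng(42))
# One example of a derived parameter (neglecting neutrinos): Omega_m = Omega_cdm + Omega_b.
omega_m = samples[:, names.index("Omega_cdm")] + samples[:, names.index("Omega_b")]
```

Compared with independent uniform draws, the stratification guarantees that the marginal of every parameter is evenly covered, which matters with only 1000 training cosmologies in five dimensions.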

I will also restrict the redshift range to [0, 3.5]. For k, the range is [5e-4, 50] h/Mpc, but the upper bound will be closely linked to the dP/P accuracy (i.e. the relative difference with respect to the CLASS computation).
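One way to turn "the upper bound is linked to dP/P accuracy" into a number is to compute dP/P = P_emu/P_CLASS − 1 on validation cosmologies and keep the largest k below which the worst-case |dP/P| stays under a tolerance. This is a minimal sketch; the 1% tolerance, the array layout and the synthetic spectra in the usage example are illustrative assumptions, not values from the thread.

```python
import numpy as np


def max_reliable_k(k, pk_emu, pk_class, tol=0.01):
    """Largest k such that |dP/P| <= tol at every k' <= k, over all validation cosmologies.

    k:        (n_k,) increasing wavenumbers [h/Mpc]
    pk_emu:   (n_cosmo, n_k) emulated spectra
    pk_class: (n_cosmo, n_k) reference CLASS spectra
    Returns None if the tolerance is already violated at the lowest k.
    """
    rel = np.abs(pk_emu / pk_class - 1.0)  # |dP/P| per cosmology and k
    worst = rel.max(axis=0)                # worst case over cosmologies
    bad = np.nonzero(worst > tol)[0]
    if bad.size == 0:
        return k[-1]                       # accurate over the whole range
    if bad[0] == 0:
        return None
    return k[bad[0] - 1]


# Usage with synthetic spectra whose relative error grows linearly with k.
k = np.geomspace(5e-4, 50, 200)
pk_class = np.tile(k ** -1.0, (3, 1))
pk_emu = pk_class * (1.0 + 0.002 * k * np.array([[0.5], [1.0], [0.8]]))
kmax = max_reliable_k(k, pk_emu, pk_class, tol=0.01)  # worst case 0.002 k <= 0.01 => k <= 5
```

Using the worst case over cosmologies, rather than the mean, is the conservative choice: the emulator is then trusted up to `kmax` everywhere in the sampled parameter range.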

jecampagne commented 2 years ago

Here is a result with the new training sets and the models retrained with the new building-block scheme, taking care of the CLASS k_max_h_by_Mpc parameter, which affects the stability of the Pk computations. The figure below uses the Planck15 jax-cosmo parameters, which is not one of the emulator's training cosmologies.
[Figure: emulator result for the Planck15 jax-cosmo parameters]

Comparison with jax-cosmo for the Jacobian and vmap computations. Note that at the scale of the figures, differences at the level of a few % cannot be seen.
[Figures: Jacobian and vmap comparisons with jax-cosmo]
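The two operations being compared are JAX transformations over the cosmological parameters. This minimal sketch uses a hypothetical toy P(k), a power law with amplitude A and tilt n_s around an arbitrary pivot of 0.05, which is neither the emulator nor jax-cosmo. `jax.jacfwd` gives dP(k)/dθ at all k at once, and `jax.vmap` evaluates P(k) for a batch of cosmologies in one call.

```python
import jax
import jax.numpy as jnp


def toy_pk(theta, k):
    """Hypothetical stand-in for P(k; theta), with theta = (A, n_s)."""
    amp, n_s = theta[0], theta[1]
    return amp * (k / 0.05) ** n_s


k = jnp.geomspace(1e-3, 1.0, 64)
theta0 = jnp.array([2.0, 0.96])

# Jacobian dP(k_i)/dtheta_j, shape (n_k, n_params); forward mode suits few parameters.
jac = jax.jacfwd(toy_pk)(theta0, k)

# A batch of cosmologies evaluated in one vectorised call, shape (n_cosmo, n_k).
thetas = jnp.stack([jnp.linspace(1.5, 2.5, 10), jnp.linspace(0.9, 1.0, 10)], axis=1)
pk_batch = jax.vmap(toy_pk, in_axes=(0, None))(thetas, k)

# The transformations compose: one Jacobian per cosmology, shape (n_cosmo, n_k, n_params).
jac_batch = jax.vmap(jax.jacfwd(toy_pk), in_axes=(0, None))(thetas, k)
```

Forward mode (`jacfwd`) costs one pass per input parameter, so it is the natural choice when a handful of cosmological parameters map to many P(k) values.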

jecampagne commented 2 years ago

In fact, the 1D and 2D interpolations on the (k, z) grid are constrained to be performed on a regular grid (log scale in k, linear scale in z). Ideally I would sample differently to get better accuracy in the interval k = [0.1, 1] h/Mpc, but at the end of the day the bottleneck is certainly the sampling of the cosmological parameters.
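On a regular grid, the cell containing a query point is found by arithmetic rather than by a search, which keeps the interpolation cheap and jit-friendly. This is a minimal JAX sketch of bilinear interpolation uniform in log k and in z; the function names and grid sizes are hypothetical, and the emulator's actual interpolation scheme may differ.

```python
import jax
import jax.numpy as jnp


def make_grid(k_min, k_max, n_k, z_min, z_max, n_z):
    """Regular grid: uniform in log(k), uniform in z."""
    logk = jnp.linspace(jnp.log(k_min), jnp.log(k_max), n_k)
    z = jnp.linspace(z_min, z_max, n_z)
    return logk, z


def _cell(x, x0, dx, n):
    """Lower cell index and fractional offset on a uniform 1D grid.

    The index is clamped, so points outside the grid are linearly extrapolated
    from the edge cell.
    """
    t = (x - x0) / dx
    i = jnp.clip(jnp.floor(t).astype(jnp.int32), 0, n - 2)
    return i, t - i


@jax.jit
def interp_logk_z(table, logk_grid, z_grid, k, z):
    """Bilinear interpolation of table[i_k, i_z] at (log k, z); k and z broadcast together."""
    ik, fk = _cell(jnp.log(k), logk_grid[0], logk_grid[1] - logk_grid[0], logk_grid.size)
    iz, fz = _cell(z, z_grid[0], z_grid[1] - z_grid[0], z_grid.size)
    f00 = table[ik, iz]
    f10 = table[ik + 1, iz]
    f01 = table[ik, iz + 1]
    f11 = table[ik + 1, iz + 1]
    return ((1 - fk) * (1 - fz) * f00 + fk * (1 - fz) * f10
            + (1 - fk) * fz * f01 + fk * fz * f11)


# Usage: a table that is linear in (log k, z) is reproduced exactly by bilinear interpolation.
logk, zg = make_grid(5e-4, 50.0, 120, 0.0, 3.5, 20)
table = 2.0 * logk[:, None] - 0.5 * zg[None, :] + 1.0
kq = jnp.geomspace(1e-3, 10.0, 7)[:, None]
zq = jnp.array([0.0, 1.0, 2.0, 3.0])[None, :]
out = interp_logk_z(table, logk, zg, kq, zq)  # shape (7, 4)
```

The price of the regular grid is the one mentioned above: resolution cannot be concentrated where it is most needed, e.g. around k = 0.1 to 1 h/Mpc, without refining the whole log k axis.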

jecampagne commented 2 years ago

I would like to share something rather mundane (sorry, this is not great ML or Bayesian theory): the time spent loading the parameters of the emulators.

In the current version, I need to load 2×120 GPs (PkLin and PkNLin at z=0) + 2×120×20 GPs (the functions P(k,z)/P(k,z=0), both linear and non-linear), i.e. about 5000 GPs (2×120 + 2×120×20 = 5040). The total amount of data is 288 MB (not that much, is it?).

Depending on where the parameter files are located and on the link bandwidth, loading takes the following (here at the Computing Centre in Lyon, with the data in the system cache):

Growth end-start (sec) 10.849421180784702
Pk Lin end-start (sec) 0.5452490821480751
Pk Non Lin end-start (sec) 0.541755635291338
Q func end-start (sec) 11.04175541549921

But it can take 10 times longer if the data are not in the system cache. And on Google Colab, with the data on "MyDrive", it took 1 h to load them into a notebook! Two days later, it took 30 s...

Can someone help speed up this tedious process?

Also note that for the computation of the linear and non-linear Pk:

pk_linear_interp = emu.linear_pk(cosmo_jax, k_star, z_star)
pk_nonlin_interp = emu.nonlinear_pk(cosmo_jax, k_star, z_star)

with 1200 k values (k_star) at 4 different redshifts (z_star), the XLA compilation time varies and can reach 10 min. After that, execution takes about 15-30 ms.
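The slow first call followed by fast calls is JAX's trace-and-compile step, which runs once per new combination of input shapes and dtypes and is then reused. This minimal sketch times the two phases separately using JAX's ahead-of-time `lower(...).compile()`; the function `f` is a hypothetical placeholder for the emulator call. `block_until_ready()` matters because JAX dispatches work asynchronously, so without it the timer can stop before the computation has finished.

```python
import time

import jax
import jax.numpy as jnp


def f(k, z):
    """Hypothetical placeholder for an emulator evaluation on a (k, z) grid."""
    return jnp.sin(jnp.log(k))[:, None] * jnp.exp(-z)[None, :]


f_jit = jax.jit(f)
k_star = jnp.geomspace(5e-4, 50, 1200)  # h/Mpc
z_star = jnp.array([0.0, 1.0, 2.0, 3.0])

# Ahead of time: trace and compile explicitly, without running the computation.
t0 = time.perf_counter()
compiled = f_jit.lower(k_star, z_star).compile()
t_compile = time.perf_counter() - t0

# Execution only: the compiled executable is reused for inputs of the same shape and dtype.
t0 = time.perf_counter()
out = compiled(k_star, z_star).block_until_ready()
t_run = time.perf_counter() - t0

print(f"compile: {t_compile:.3f} s, run: {t_run * 1e3:.2f} ms")
```

Separating the two numbers makes it clear which optimisation is needed: a long `t_compile` points at the size of the traced program, a long `t_run` at the computation itself.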

jecampagne commented 2 years ago

I have refactored the code (not yet public), which now allows a jitted version of:

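import jax.numpy as jnp
k_star = jnp.geomspace(5e-4, 50, 1200, endpoint=True)  # h/Mpc
z_star = jnp.array([0., 1., 2., 3.])
# emu: the (not yet public) emulator object; cosmo_jax: the input cosmology
pk_linear_interp = emu.linear_pk(cosmo_jax, k_star, z_star)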

XLA compilation takes 400 s, while subsequent executions take 25 ms. The compilation time could perhaps be reduced by using more idiomatic JAX.
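A common cause of long XLA compile times in emulators built from many independent GPs is a Python loop over the models inside a jitted function: JAX unrolls the loop at trace time, so the compiled program grows with the number of GPs (thousands here). This sketch shows the alternative under a hypothetical assumption, namely that all GPs share the same number of training points so their parameters can be stacked into arrays and mapped with `jax.vmap`; the program is then traced once, whatever the number of GPs. It is an illustration, not the emulator's code.

```python
import jax
import jax.numpy as jnp


def gp_mean(x_new, x_train, alpha, lengthscale):
    """Posterior mean of one zero-mean RBF GP at a single input x_new of shape (d,)."""
    d2 = jnp.sum(((x_train - x_new) / lengthscale) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2) @ alpha


@jax.jit
def predict_loop(x_new, X, A, L):
    # Unrolled at trace time: one copy of gp_mean per GP ends up in the compiled program.
    return jnp.stack([gp_mean(x_new, X[i], A[i], L[i]) for i in range(X.shape[0])])


@jax.jit
def predict_vmap(x_new, X, A, L):
    # Traced once; vmap maps over the leading "GP index" axis of the stacked parameters.
    return jax.vmap(gp_mean, in_axes=(None, 0, 0, 0))(x_new, X, A, L)


# Usage: 50 toy GPs, 30 training points each, 5 input parameters.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n_gp, n_train, d = 50, 30, 5
X = jax.random.uniform(k1, (n_gp, n_train, d))
A = jax.random.normal(k2, (n_gp, n_train))
L = 0.5 + jax.random.uniform(k3, (n_gp, d))
x_new = jnp.full((d,), 0.5)

out_loop = predict_loop(x_new, X, A, L)
out_vmap = predict_vmap(x_new, X, A, L)
```

Both versions return the same numbers; the difference is only in the size of the program XLA has to compile, which is what dominates with thousands of GPs.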

jecampagne commented 2 years ago

I managed to speed up the data loading with parallelized code: it now takes 25 s at CC and on Google Colab. Of course, with more bandwidth this overhead could be reduced further. Concerning JAX/JIT, I am doing an exhaustive code review and tweaking XLA, but the work is not yet satisfactory. Even without JIT, the emulator is already faster than jax-cosmo. With JIT turned on, the XLA compilation can take several minutes, but then the emulator runs at the millisecond level.
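The comment does not say how the loading was parallelized. One common approach, sketched here as an assumption, is a thread pool: reading many files is I/O-bound, and threads overlap the waits on disk or network storage. The one-`.npz`-file-per-GP layout and the file names are hypothetical.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def load_one(path):
    """Load one GP's parameters from an .npz file into a plain dict of arrays."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}


def load_all(paths, max_workers=32):
    """Load many parameter files concurrently; results keep the order of `paths`."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(load_one, paths))


# Usage with dummy files standing in for the emulator parameters.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(100):
        p = os.path.join(tmp, f"gp_{i:04d}.npz")
        np.savez(p, alpha=np.full(30, float(i)), lengthscales=np.ones(5))
        paths.append(p)
    params = load_all(paths)
```

Packing the parameters into a few large files (for instance one stacked array per parameter over all GPs) usually helps as much as parallelism, since it replaces thousands of small reads with a few large ones.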

jecampagne commented 2 years ago

With the help of XLA experts, I carried out a code review to use PyTrees and to apply JIT sparingly. I managed to get an XLA compilation time of ~8 s and an execution time of about 1.5 s for 1200×4 (k, z) evaluations, for both the linear and non-linear Pk. The emulator is now faster than the current jax_cosmo library for the same types of computations (e.g. vmap and jacfwd). The Git repo has been updated accordingly.
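One concrete reading of "PyTrees plus sparing use of JIT", offered as an assumption since the code is not shown: keep the emulator parameters in a PyTree (a `NamedTuple` is one automatically) and pass it as an argument to a single top-level jitted function, instead of closing over the arrays. Arrays captured by a closure are embedded as constants in the compiled program, which can inflate compile time when they total hundreds of MB, whereas arrays passed as arguments are ordinary inputs. The container and field names below are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EmuParams(NamedTuple):
    """Hypothetical container for stacked GP parameters (a NamedTuple is a PyTree)."""
    x_train: jnp.ndarray       # (n_gp, n_train, d)
    alpha: jnp.ndarray         # (n_gp, n_train)
    lengthscales: jnp.ndarray  # (n_gp, d)


def _gp_mean(x_new, x_train, alpha, lengthscale):
    """Posterior mean of one zero-mean RBF GP at x_new."""
    d2 = jnp.sum(((x_train - x_new) / lengthscale) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2) @ alpha


@jax.jit
def predict(params: EmuParams, x_new):
    # params is a traced input: its leaves are arguments of the compiled program,
    # not constants baked into it.
    return jax.vmap(_gp_mean, in_axes=(None, 0, 0, 0))(
        x_new, params.x_train, params.alpha, params.lengthscales)


# Usage
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = EmuParams(
    x_train=jax.random.uniform(key1, (120, 30, 5)),
    alpha=jax.random.normal(key2, (120, 30)),
    lengthscales=jnp.full((120, 5), 0.5),
)
out = predict(params, jnp.full((5,), 0.5))
# Gradients flow through the same function, e.g. with respect to the input cosmology.
g = jax.grad(lambda x: predict(params, x).sum())(jnp.full((5,), 0.5))
```

Jitting only the top-level `predict`, rather than every helper, also keeps the number of separately compiled programs small.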

jecampagne commented 2 years ago

I finally managed to boost the prediction speed at a moderate cost in XLA compilation time. This is what the newly committed version of Jemu yields:

jecampagne commented 2 years ago

I write some scripts to