andrejobuljen / Hi-Fi_mocks

Codes to generate fast HI field-level mocks in real and redshift space
GNU General Public License v3.0

Differentiable version of Hi-Fi Mocks #1

Open EiffL opened 1 year ago

EiffL commented 1 year ago

Here is the beginning of an example of how to implement this using JAX: https://colab.research.google.com/drive/1uneGJL4ewmV-Oyn-9k1H22gQnFPn_yB2?usp=sharing

EiffL commented 1 year ago

And just for completeness, here is the notebook that can run HMC on a lognormal model: https://colab.research.google.com/drive/1U6HymNm0mJD-Kj07YkN7FLFfp_mZCfGW?usp=sharing

EiffL commented 1 year ago

@andrejobuljen I had a little bit of time on the train and tried to code up a full example... with moderate success ^^

The code is in this new notebook

I only implemented the terms up to dG2, not d3; I don't know how much that matters.

I also saw that you have an orthogonalization step, which I didn't implement because we didn't discuss it earlier today. That probably matters?

And finally, I am not 100% sure of the conversion from my unitless scale conventions to the h/Mpc needed to compute the filters (I'm like 80% sure). So if everything looks ok but the results still don't make sense, it could come from there.

andrejobuljen commented 1 year ago

@EiffL great thanks!

Yes, sorry, I also realised later that I forgot to mention the orthogonalisation step when we discussed yesterday. For that step I'd only need to measure various cross-Pk between fields. I tried running auto-pk but ran into the following problem:

Do you maybe have somewhere a cross-pk function written in jax?
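
A minimal sketch of what such a cross-P(k) helper could look like in JAX, for two real fields on the same periodic cubic grid. The function name `cross_power_spectrum`, its arguments, and the linear binning are assumptions made for illustration; this is not code from the notebooks or from an existing library. Passing the same field twice gives the auto-spectrum.

```python
import jax.numpy as jnp


def cross_power_spectrum(field_a, field_b, box_size, n_bins=16):
    """Spherically averaged cross-power of two real 3D fields on the same
    cubic periodic grid (hypothetical helper, not an existing API).

    Returns the mean |k| and mean P_ab(k) in each linear k-bin.
    """
    n = field_a.shape[0]
    volume = box_size ** 3

    # Continuum convention: delta(k) ~ (V / N^3) * FFT and P = <a b*> / V,
    # so the overall factor on FFT(a) * conj(FFT(b)) is V / N^6.
    fa = jnp.fft.fftn(field_a)
    fb = jnp.fft.fftn(field_b)
    pk3d = jnp.real(fa * jnp.conj(fb)) * volume / n ** 6

    # Wavenumbers in units of 1 / [box_size units]
    kfreq = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = jnp.meshgrid(kfreq, kfreq, kfreq, indexing="ij")
    kmag = jnp.sqrt(kx ** 2 + ky ** 2 + kz ** 2).ravel()

    # Linear bins from half the fundamental mode up to the Nyquist frequency;
    # the k = 0 mode falls below the first edge and is dropped.
    k_fund = 2 * jnp.pi / box_size
    k_nyq = jnp.pi * n / box_size
    edges = jnp.linspace(0.5 * k_fund, k_nyq, n_bins + 1)

    idx = jnp.digitize(kmag, edges) - 1
    valid = (idx >= 0) & (idx < n_bins)
    idx = jnp.where(valid, idx, n_bins)  # out-of-range modes go to a dump bin
    w = valid.astype(pk3d.dtype)

    counts = jnp.bincount(idx, weights=w, length=n_bins + 1)[:n_bins]
    p_sum = jnp.bincount(idx, weights=pk3d.ravel() * w, length=n_bins + 1)[:n_bins]
    k_sum = jnp.bincount(idx, weights=kmag * w, length=n_bins + 1)[:n_bins]

    counts = jnp.maximum(counts, 1.0)  # avoid division by zero in empty bins
    return k_sum / counts, p_sum / counts
```

With `box_size` in Mpc/h, the returned k is in h/Mpc and P(k) in (Mpc/h)^3. The sketch applies no CIC window deconvolution and no shot-noise subtraction, so those would need to be handled separately if they matter for the comparison.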

Also, thanks for fixing the weights in `cic_paint`! I did something similar yesterday by adding `if weight is not None: kernel = kernel * weight[..., jnp.newaxis]` below the line `kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]`, but I will just use your new version.
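
For reference, the weighting described above could look like this in a stand-alone CIC painter. This is a sketch written for illustration: the function name `cic_paint_weighted` and its signature are hypothetical, and it is not the notebook's or the library's actual `cic_paint`. Positions are assumed to be in grid units on a periodic mesh.

```python
import jax.numpy as jnp


def cic_paint_weighted(mesh, positions, weight=None):
    """Cloud-in-cell deposit of particles onto a periodic 3D mesh.

    mesh:      (Nx, Ny, Nz) array to add to
    positions: (N, 3) particle positions in grid units
    weight:    optional (N,) per-particle weights
    """
    shape = jnp.array(mesh.shape)
    pos = positions[:, None, :]                      # (N, 1, 3)
    floor = jnp.floor(pos)

    # Offsets of the 8 cell corners surrounding each particle
    corners = jnp.array(
        [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
        dtype=positions.dtype,
    )                                                # (8, 3)
    neighbours = floor + corners[None, :, :]         # (N, 8, 3)

    # Linear (CIC) kernel along each axis, then the product over axes
    kernel = 1.0 - jnp.abs(pos - neighbours)         # (N, 8, 3)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # (N, 8)

    # Per-particle weights multiply all 8 corner contributions
    if weight is not None:
        kernel = kernel * weight[..., jnp.newaxis]   # (N, 1) broadcasts to (N, 8)

    # Periodic wrap of the corner indices, then scatter-add onto the mesh
    idx = jnp.mod(neighbours.astype(jnp.int32), shape)
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(kernel)
```

A quick sanity check for any such painter is mass conservation: the sum over the painted mesh should equal the sum of the weights.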

Btw, do you know how to match the jax seed to numpy.random.seed when generating initial conditions? In nbodykit it was enough to use the same random seed that was used in the TNG simulation and I'd get the same ICs/phases. Here, in jax, that doesn't seem to work, i.e. when I use the same seed I get a different field. This may not be super important; it would just make testing the code quicker...
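
A small illustration of the seed mismatch, and one possible workaround. JAX's `jax.random` uses its own counter-based generator driven by explicit keys, while `numpy.random.seed` seeds NumPy's global Mersenne Twister, so the same integer seed is not expected to give the same numbers. The sketch below assumes it is acceptable to draw the white noise outside JAX and pass it in as an array; the seed value and grid shape are arbitrary choices for the example.

```python
import numpy as np
import jax
import jax.numpy as jnp

seed = 42          # arbitrary example seed
shape = (8, 8, 8)  # arbitrary example grid

# Same integer seed, two different generators
noise_jax = jax.random.normal(jax.random.PRNGKey(seed), shape)

np.random.seed(seed)
noise_np = np.random.standard_normal(shape)

# The two fields differ, because the underlying PRNG algorithms differ
same = np.allclose(np.asarray(noise_jax), noise_np)

# Possible workaround: generate the noise once with NumPy (or load it from
# disk) and hand it to the JAX code as a plain array.
noise_shared = jnp.asarray(noise_np)
```

Reproducing the nbodykit/TNG phases this way would additionally require that the noise is drawn in exactly the same way as nbodykit does it (same generator, same ordering, real vs Fourier space), so saving the field produced by nbodykit and loading it into the JAX code may be the more reliable route.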

Cheers, Andrej