EiffL opened this issue 1 year ago.
And just for completeness, here is the notebook that can run HMC on a lognormal model: https://colab.research.google.com/drive/1U6HymNm0mJD-Kj07YkN7FLFfp_mZCfGW?usp=sharing
@andrejobuljen I had a little bit of time on the train and tried to code up a full example... with moderate success ^^
The code is in this new notebook
I only implemented up to dG2, not d3; I don't know how much that matters.
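For reference, here is a minimal JAX sketch of the Lagrangian operators up to G2 (no d3). It assumes the common definitions δ2 = δ² − ⟨δ²⟩ and G2 = Σ_ij (∂_i∂_j φ)² − δ², with ∇²φ = δ, so the tidal tensor in Fourier space is k_i k_j δ(k)/k². The function name `lagrangian_operators` is hypothetical, and the shifting (displacement) step is not included.

```python
import jax.numpy as jnp

def lagrangian_operators(delta, boxsize):
    """Return (d1, d2, dG2) from a linear density field on a periodic 3D grid.

    boxsize: per-axis box lengths in Mpc/h.
    """
    shape = delta.shape
    # Wavevector components in h/Mpc, laid out like jnp.fft.fftn output
    kvec = jnp.meshgrid(*[2 * jnp.pi * jnp.fft.fftfreq(n, d=l / n)
                          for n, l in zip(shape, boxsize)], indexing='ij')
    k2 = sum(ki**2 for ki in kvec)
    k2 = jnp.where(k2 == 0, 1.0, k2)  # avoid 0/0; the k = 0 mode is zeroed below
    delta_k = jnp.fft.fftn(delta).at[0, 0, 0].set(0.0)  # remove the mean
    d1 = jnp.real(jnp.fft.ifftn(delta_k))
    d2 = d1**2 - jnp.mean(d1**2)
    # Tidal tensor T_ij = d_i d_j phi with nabla^2 phi = delta -> k_i k_j / k^2 * delta(k)
    tidal_sq = 0.0
    for i in range(3):
        for j in range(3):
            t_ij = jnp.real(jnp.fft.ifftn(kvec[i] * kvec[j] / k2 * delta_k))
            tidal_sq = tidal_sq + t_ij**2
    dG2 = tidal_sq - d1**2  # G2 = T_ij T_ij - (tr T)^2, and tr T = delta
    return d1, d2, dG2
```

A quick sanity check: for a single plane wave the tidal tensor has only one non-zero component equal to δ, so G2 vanishes identically.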
I also saw that you had an orthogonalization step, which I didn't implement because we didn't discuss it earlier today. That probably matters?
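If the orthogonalization is a per-|k|-bin Gram-Schmidt projection (an assumption here, to be checked against Andrej's code), it only needs cross-power ratios: in each bin, subtract from field b its projection onto field a, with coefficient M(k) = P_ab(k)/P_aa(k). A minimal sketch; `orthogonalize` and the choice of bin edges are hypothetical.

```python
import jax.numpy as jnp
import numpy as np

def orthogonalize(field_a, field_b, boxsize, kedges):
    """Remove from field_b the part correlated with field_a, bin by bin in |k|.

    boxsize in Mpc/h, kedges in h/Mpc.
    """
    shape = field_a.shape
    kvec = np.meshgrid(*[2 * np.pi * np.fft.fftfreq(n, d=l / n)
                         for n, l in zip(shape, boxsize)], indexing='ij')
    kmag = np.sqrt(sum(ki**2 for ki in kvec))
    dig = np.digitize(kmag, kedges)  # bin index of every Fourier mode
    nbins = len(kedges) + 1          # includes under- and overflow bins
    fa = jnp.fft.fftn(field_a)
    fb = jnp.fft.fftn(field_b)
    # Normalisation constants cancel in the ratio, so raw FFT sums suffice
    cross = jnp.bincount(dig.ravel(), weights=jnp.real(fa * jnp.conj(fb)).ravel(), length=nbins)
    auto = jnp.bincount(dig.ravel(), weights=jnp.real(fa * jnp.conj(fa)).ravel(), length=nbins)
    ratio = jnp.where(auto > 0, cross / jnp.where(auto > 0, auto, 1.0), 0.0)
    # Subtract the projection onto field_a in every bin
    fb_perp = fb - ratio[dig] * fa
    return jnp.real(jnp.fft.ifftn(fb_perp))
```

Because the projection coefficient depends only on |k|, the result stays Hermitian, so taking the real part of the inverse FFT only discards rounding noise.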
And finally, I am not 100% sure about the conversion from my unitless scale conventions to the h/Mpc needed to compute the filters (I'm about 80% sure). So if everything looks OK but the results still don't make sense, the problem could come from there.
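A minimal sketch of that conversion, assuming the unitless convention is wavenumbers in radians per grid cell (as produced by `2 * np.pi * np.fft.fftfreq(N)`, so the Nyquist mode is π). Dividing by the cell size L/N (Mpc/h) gives h/Mpc. If the convention is instead in units of the fundamental mode, multiply by 2π/L. The helper names are hypothetical.

```python
import numpy as np

def grid_k_to_hmpc(k_grid, n_mesh, box_size):
    """Convert wavenumbers in radians per grid cell to h/Mpc (box_size in Mpc/h)."""
    cell_size = box_size / n_mesh  # Mpc/h per cell
    return k_grid / cell_size

def hmpc_to_grid_k(k_hmpc, n_mesh, box_size):
    """Inverse conversion: h/Mpc -> radians per grid cell."""
    return k_hmpc * box_size / n_mesh
```

A filter such as a Gaussian exp(-k²R²/2) with R in Mpc/h then needs k in h/Mpc, i.e. `grid_k_to_hmpc(k_grid, N, L)`.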
@EiffL great, thanks!
Yes, sorry, I also realised later that I forgot to mention the orthogonalisation step when we discussed it yesterday. For that step I'd only need to measure various cross-Pk between fields. I tried running the auto-Pk but ran into the following problem:
If I run `p1 = power_spectrum(shifted1, kmin=jnp.pi/256, dk=2.*jnp.pi/256, boxsize=box_size[0])`, I get this error due to `boxsize`: `TypeError: 'float' object is not iterable`.
If I run `p1 = power_spectrum(shifted1, kmin=jnp.pi/256, dk=2.*jnp.pi/256, boxsize=box_size)`, I get this error, again due to `boxsize`: `AttributeError: 'list' object has no attribute 'prod'`.
So I replaced `P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')` with `P = ((Psum / Nsum)[1:-1] * jnp.array(boxsize).prod()).astype('float32')`, ran it with `boxsize=box_size`, and it worked.
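An alternative to patching the library line, sketched under the assumption (suggested by both error messages) that `power_spectrum` only needs `boxsize` to be iterable and to support `.prod()`: convert `box_size` to a NumPy array before the call. `as_boxsize` is a hypothetical helper.

```python
import numpy as np

def as_boxsize(box_size, ndim=3):
    """Return box_size as a float ndarray of shape (ndim,).

    Accepts a scalar (cubic box) or a sequence of per-axis lengths.
    """
    arr = np.asarray(box_size, dtype=float)
    if arr.ndim == 0:
        arr = np.full(ndim, float(arr))  # scalar -> cubic box
    if arr.shape != (ndim,):
        raise ValueError(f"expected {ndim} box lengths, got shape {arr.shape}")
    return arr
```

Usage would then be `power_spectrum(shifted1, kmin=jnp.pi/256, dk=2.*jnp.pi/256, boxsize=as_boxsize(box_size))`, leaving the library code untouched.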
Do you maybe have somewhere a cross-pk function written in jax?
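A minimal sketch of a JAX cross-power spectrum, not taken from any existing library. It bins Re(δ₁δ₂*) in spherical |k| shells and uses the usual V/N_cells² normalisation. It works on the full complex FFT to avoid the half-plane weighting needed with `rfftn`, at the cost of twice the memory. The function name and argument conventions (`kmin`, `dk` in h/Mpc, `boxsize` in Mpc/h) are assumptions.

```python
import jax.numpy as jnp
import numpy as np

def cross_power_spectrum(field1, field2, boxsize, kmin, dk):
    """Binned cross-power spectrum P_12(k) of two real 3D fields on the same grid.

    boxsize: per-axis box lengths in Mpc/h; kmin, dk in h/Mpc.
    Returns bin-averaged k (h/Mpc) and P_12(k) in (Mpc/h)^3.
    """
    shape = field1.shape
    boxsize = np.asarray(boxsize, dtype=float)
    # |k| of every mode in the full (complex) FFT layout, in h/Mpc
    kvec = np.meshgrid(*[2 * np.pi * np.fft.fftfreq(n, d=l / n)
                         for n, l in zip(shape, boxsize)], indexing='ij')
    kmag = np.sqrt(sum(ki**2 for ki in kvec)).ravel()
    kmax = np.pi * min(shape) / boxsize.max()  # conservative Nyquist limit
    kedges = np.arange(kmin, kmax + dk, dk)
    dig = np.digitize(kmag, kedges)  # 0: below kmin, len(kedges): above the last edge
    nbins = len(kedges) + 1
    f1 = jnp.fft.fftn(field1)
    f2 = jnp.fft.fftn(field2)
    # Re(delta_1 delta_2^*); the imaginary part cancels between k and -k for real fields
    cross = jnp.real(f1 * jnp.conj(f2)).ravel()
    Psum = jnp.bincount(dig, weights=cross, length=nbins)
    Nsum = jnp.bincount(dig, length=nbins)
    ksum = jnp.bincount(dig, weights=kmag, length=nbins)
    # Drop the underflow/overflow bins and guard against empty bins
    Psum, Nsum, ksum = Psum[1:-1], Nsum[1:-1], ksum[1:-1]
    Nsafe = jnp.maximum(Nsum, 1)
    # Same normalisation as an auto-spectrum: V / N_cells^2
    norm = boxsize.prod() / float(np.prod(shape)) ** 2
    return ksum / Nsafe, Psum / Nsafe * norm
```

Passing the same field twice gives the auto-spectrum, which can be compared against the existing `power_spectrum` as a consistency check (small differences in binning conventions are possible).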
Also thanks for fixing the weights in `cic_paint`! I did something similar yesterday, adding `if weight is not None: kernel = kernel * weight[..., jnp.newaxis]` below the line `kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]`, but I will just use your new version.
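For reference, a minimal self-contained sketch of a weighted cloud-in-cell deposit using the same kernel-times-weight idea. It is not the library's `cic_paint`: it uses `.at[].add` for the scatter (which accumulates duplicate indices), and the name `cic_paint_weighted` is hypothetical. Positions are assumed to be in grid units on a periodic mesh.

```python
import jax.numpy as jnp

def cic_paint_weighted(mesh, positions, weight=None):
    """Cloud-in-cell deposit of particles onto a periodic 3D mesh.

    positions: (n, 3) in grid units; weight: optional (n,) per-particle weights.
    """
    # Offsets to the 8 neighbouring cells of each particle
    offsets = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                         [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]],
                        dtype=positions.dtype)
    pos = positions[:, None, :]               # (n, 1, 3)
    neighbours = jnp.floor(pos) + offsets     # (n, 8, 3)
    kernel = 1.0 - jnp.abs(pos - neighbours)  # per-axis linear weights
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # (n, 8)
    if weight is not None:
        kernel = kernel * weight[..., jnp.newaxis]  # scale each particle's 8 weights
    # Wrap indices periodically onto the mesh
    idx = jnp.mod(neighbours.astype(jnp.int32), jnp.array(mesh.shape))
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(kernel)
```

Since the 8 kernel values of each particle sum to 1, the painted mesh sums to the total weight (or to the particle count when `weight` is `None`).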
Btw, do you know how to match the JAX seed to `numpy.random.seed` when generating initial conditions? In nbodykit it was enough to use the same random seed as the TNG simulation, and I'd get the same ICs/phases. Here, in JAX, that doesn't seem to work: when I use the same seed I get a different field. This may not be super important; it would just make testing the code quicker...
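The difference is expected: `jax.random` uses a counter-based Threefry generator with explicit keys, while `numpy.random.seed` drives NumPy's legacy Mersenne Twister, so the same integer seed gives unrelated streams. One possible workaround, sketched below, is to draw the white noise with NumPy and hand it to JAX as an array; the helper names are hypothetical. Whether this reproduces the TNG/nbodykit phases also depends on how nbodykit generates its noise, which is not checked here. Saving the linear field from nbodykit and loading it into JAX is likely the more robust option.

```python
import numpy as np
import jax
import jax.numpy as jnp

def white_noise_from_numpy(seed, shape):
    """Gaussian white noise from NumPy's legacy seeding, returned as a JAX array."""
    rng = np.random.RandomState(seed)  # same stream as np.random.seed(seed) + np.random.normal
    return jnp.asarray(rng.normal(size=shape), dtype=jnp.float32)

def white_noise_from_jax(seed, shape):
    """Gaussian white noise from JAX's own PRNG (a different stream for the same seed)."""
    return jax.random.normal(jax.random.PRNGKey(seed), shape)
```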
Cheers, Andrej
Here is the beginning of an example of how to implement this using JAX: https://colab.research.google.com/drive/1uneGJL4ewmV-Oyn-9k1H22gQnFPn_yB2?usp=sharing