Open EiffL opened 2 years ago
My first prototype model, using the DC2 n(z) provided by @elts6570 in #5, a low-resolution PM N-body, an E&H (Eisenstein & Hu) power spectrum, and Born convergence, can be found here
Here is what the model looks like:
# Imports inferred from the names used below; nbody, kappa, nz_shear,
# sigma_e and galaxy_density are defined elsewhere in the notebook.
import jax.numpy as jnp
import jax_cosmo as jc
import numpyro
import numpyro.distributions as dist
from jax_cosmo.scipy.integrate import simps

def forward_model():
    """
    This function defines the top-level forward model for our observations
    """
    # Sample cosmological parameters and define the cosmology
    Omega_c = numpyro.sample('Omega_c', dist.Uniform(0.1, 0.9))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.4, 1.0))
    Omega_b = numpyro.sample('Omega_b', dist.Uniform(0.03, 0.07))
    h = numpyro.sample('h', dist.Uniform(0.55, 0.91))
    n_s = numpyro.sample('n_s', dist.Uniform(0.87, 1.07))
    w0 = numpyro.sample('w0', dist.Uniform(-2.0, -0.33))
    cosmo = jc.Cosmology(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                         h=h, n_s=n_s, w0=w0, Omega_k=0., wa=0.)

    # Generate lightcone density planes with an N-body simulation
    density_planes = nbody(cosmo)

    # Create photo-z systematics parameters, and create the derived n(z)
    nzs_s_sys = [jc.redshift.systematic_shift(nzi,
                                              numpyro.sample('dz%d'%i, dist.Normal(0., 0.01)),
                                              zmax=2.5)
                 for i, nzi in enumerate(nz_shear)]

    # Generate convergence maps by integrating over n(z) and source planes
    convergence_maps = [simps(lambda z: nz(z).reshape([-1,1,1]) * kappa(cosmo, density_planes, z), 0., 2.5, N=64)
                        for nz in nzs_s_sys]

    # Apply noise to the maps (this defines the likelihood)
    observed_maps = [numpyro.sample('kappa_%d'%i,
                                    dist.Normal(k, sigma_e/jnp.sqrt(10**2*galaxy_density)))  # assumes pixel size of 10 arcmin
                     for i, k in enumerate(convergence_maps)]

    return observed_maps
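For anyone wondering where the noise level in the likelihood comes from: the standard deviation `sigma_e/jnp.sqrt(10**2*galaxy_density)` is the per-component ellipticity dispersion divided by the square root of the number of galaxies falling in a pixel. Here is a minimal sketch of that arithmetic in plain Python. The numerical values of `sigma_e` and `galaxy_density`, and the helper name `pixel_noise_std`, are illustrative assumptions, not the values or code used in the notebook.

```python
import math

# Illustrative assumptions only, not the values used in the notebook.
sigma_e = 0.26          # ellipticity dispersion per shear component
galaxy_density = 27.0   # source galaxies per arcmin^2 in one tomographic bin
pixel_size = 10.0       # pixel side in arcmin (matches the 10**2 in the model)


def pixel_noise_std(sigma_e, galaxy_density, pixel_size):
    """Gaussian noise std of one convergence pixel (hypothetical helper)."""
    # Number of galaxies per pixel = density * pixel area in arcmin^2.
    n_gal_per_pixel = galaxy_density * pixel_size ** 2
    # Averaging n independent ellipticities shrinks the scatter by sqrt(n).
    return sigma_e / math.sqrt(n_gal_per_pixel)


sigma_pix = pixel_noise_std(sigma_e, galaxy_density, pixel_size)
print(f"noise std per {pixel_size:g} arcmin pixel: {sigma_pix:.5f}")
```

Note that halving the pixel size quarters the number of galaxies per pixel and so doubles the per-pixel noise, which is why the pixel size assumption matters in the likelihood.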
It uses these true n(z) for source redshift:
and generates corresponding tomographic convergence maps:
I'm going to commit the code in a branch for review.
I do not see where the 'fn' entries are generated in
trace['kappa_%d'%i]['fn']
They are generated here:
numpyro.sample('kappa_%d'%i,
               dist.Normal(k, sigma_e/jnp.sqrt(10**2*galaxy_density)))
Ah, yes! The trace returns a dictionary with an 'fn' key for each sample site, as in this example (unrelated to the present use case):
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'rng_key': DeviceArray([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': DeviceArray(-0.20584235, dtype=float32)})])
This notebook runs HMC on a simplified full-field inference problem: https://colab.research.google.com/drive/1_2VbA1a3sSrxzdPjt1c-QGe7NmK5vLh_?usp=sharing
We can use this issue to discuss how we want to build our framework for describing the forward model.
I'm going to propose a framework based on JAX and numpyro as the PPL, which we can then discuss :-)