hsimonfroy / montecosmo

Differentiable pipeline for field-level cosmological inference from galaxy surveys.

Model Explicit Likelihood Inference #3

Open hsimonfroy opened 6 months ago

hsimonfroy commented 6 months ago

Context

Documenting field-level explicit likelihood inference from a differentiable cosmological model.

In the code, we run joint inference of the initial field ($64^3$ mesh), the cosmological parameters ($\Omega_c$ and $\sigma_8$), and the Lagrangian bias parameters.

For single-chain samplers, chains are initialized at the fiducial values.

To assess chain convergence, and thus detect a potential bias in the sampling process, we should limit the other sources of bias. For instance, observations should be generated with no likelihood noise. This does not remove every source of bias: for a simple $\mathcal N(x \mid z^2, I)$ likelihood model, even a noiseless observation yields a bimodal posterior on $z$, with modes at $\pm z_\text{true}$. In any case, samplers will be compared against a reference posterior obtained by an Implicit Likelihood Inference method.
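A minimal sketch of this point in NumPyro (the toy model below is illustrative, not part of the pipeline):

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def toy_model(x_obs):
    # Standard normal prior on the latent z.
    z = numpyro.sample("z", dist.Normal(0.0, 1.0))
    # Likelihood N(x | z^2, 1): the observation only constrains |z|.
    numpyro.sample("x", dist.Normal(z**2, 1.0), obs=x_obs)

# Noiseless observation generated at the fiducial z_true = 1.
x_obs = jnp.array(1.0**2)

mcmc = MCMC(NUTS(toy_model), num_warmup=500, num_samples=2000)
mcmc.run(random.PRNGKey(0), x_obs)
# The z samples concentrate near +1 and -1: noiseless data does not
# yield a unimodal posterior centered on the truth, and a single
# chain may even get trapped in one mode.
mcmc.print_summary()
```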

Aims

Samplers to test:

hsimonfroy commented 6 months ago

NUTS 🥜

NUTS max_tree_depth comparison.

In the tested range of parameters, Hamiltonian trajectories almost always reach the maximal length of 2**max_tree_depth - 1 steps, i.e. the U-turn criterion is essentially never triggered.

num_samples is adapted to max_tree_depth so that each run requires about 300,000 (+60,000 warmup) Hamiltonian steps, i.e. about 3h of wall time. Put differently, num_samples * (2**max_tree_depth - 1) $\simeq$ 300,000.
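A sketch of this budgeting (the model below is a trivial placeholder for the actual field-level model):

```python
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():  # placeholder for the actual field-level model
    numpyro.sample("z", dist.Normal(0.0, 1.0))

budget_steps = 300_000    # sampling budget in Hamiltonian steps
warmup_steps = 60_000     # warmup budget in Hamiltonian steps
max_tree_depth = 10

# Trajectories almost always saturate at 2**max_tree_depth - 1 steps,
# so the sample counts are chosen to spend the whole budget:
# num_samples * (2**max_tree_depth - 1) ~= budget_steps.
traj_len = 2**max_tree_depth - 1
num_samples = budget_steps // traj_len   # ~293 for mtd=10
num_warmup = warmup_steps // traj_len

mcmc = MCMC(NUTS(model, max_tree_depth=max_tree_depth),
            num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
```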

*(figure: NUTS max_tree_depth comparison)*

Summary Tables

### NUTS, mtd=12

using 75 rows, 6 parameters; mean weight 1.0, tot weight 75.0

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff | r_hat |
|---------|-------|------|--------|-------|-------|-------|-------|
| Omega_c | 0.25  | 0.01 | 0.25   | 0.24  | 0.26  | 28.92 | 1.02  |
| b1      | 0.98  | 0.03 | 0.98   | 0.94  | 1.03  | 11.48 | 1.11  |
| b2      | -0.03 | 0.01 | -0.03  | -0.05 | -0.01 | 16.63 | 1.07  |
| bnl     | 0.32  | 0.24 | 0.32   | -0.05 | 0.68  | 34.93 | 1.06  |
| bs      | -0.01 | 0.04 | -0.01  | -0.08 | 0.07  | 28.06 | 1.18  |
| sigma8  | 0.84  | 0.01 | 0.84   | 0.82  | 0.86  | 14.29 | 1.02  |

### NUTS, mtd=10

using 300 rows, 6 parameters; mean weight 1.0, tot weight 300.0

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff | r_hat |
|---------|-------|------|--------|-------|-------|-------|-------|
| Omega_c | 0.25  | 0.01 | 0.24   | 0.23  | 0.26  | 60.32 | 1.04  |
| b1      | 0.98  | 0.03 | 0.98   | 0.93  | 1.05  | 6.95  | 1.00  |
| b2      | -0.03 | 0.02 | -0.03  | -0.06 | 0.00  | 6.78  | 1.03  |
| bnl     | 0.39  | 0.30 | 0.40   | -0.10 | 0.87  | 36.07 | 1.02  |
| bs      | -0.00 | 0.05 | -0.00  | -0.08 | 0.07  | 9.04  | 1.05  |
| sigma8  | 0.83  | 0.02 | 0.83   | 0.80  | 0.86  | 6.56  | 1.01  |

### NUTS, mtd=8

using 1215 rows, 6 parameters; mean weight 1.0, tot weight 1215.0

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff | r_hat |
|---------|-------|------|--------|-------|-------|-------|-------|
| Omega_c | 0.25  | 0.01 | 0.25   | 0.23  | 0.26  | 41.84 | 1.00  |
| b1      | 0.99  | 0.02 | 0.99   | 0.96  | 1.01  | 16.98 | 1.01  |
| b2      | -0.03 | 0.01 | -0.03  | -0.05 | -0.01 | 18.18 | 1.03  |
| bnl     | 0.38  | 0.23 | 0.37   | -0.01 | 0.72  | 14.33 | 1.09  |
| bs      | 0.02  | 0.04 | 0.01   | -0.04 | 0.08  | 15.78 | 1.11  |
| sigma8  | 0.83  | 0.01 | 0.83   | 0.82  | 0.84  | 14.88 | 1.03  |

### NUTS, mtd=3

using 44290 rows, 6 parameters; mean weight 1.0, tot weight 44290.0

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff | r_hat |
|---------|-------|------|--------|-------|-------|-------|-------|
| Omega_c | 0.25  | 0.01 | 0.25   | 0.24  | 0.26  | 8.39  | 1.02  |
| b1      | 1.03  | 0.01 | 1.03   | 1.01  | 1.05  | 15.61 | 1.03  |
| b2      | -0.01 | 0.01 | -0.01  | -0.03 | 0.01  | 4.96  | 1.28  |
| bnl     | 0.17  | 0.26 | 0.15   | -0.24 | 0.60  | 5.79  | 1.48  |
| bs      | 0.02  | 0.04 | 0.03   | -0.04 | 0.09  | 3.56  | 1.73  |
| sigma8  | 0.82  | 0.01 | 0.82   | 0.81  | 0.83  | 7.73  | 1.00  |

NUTS benchmarking

In NumPyro's implementation of NUTS, the number of evaluations of the model, of its log-probability, or of its score can easily be accessed. At step i, NUTS makes extra_fields['num_steps'][i] calls to value_and_grad(), which relies on vjp. The JAX documentation states that:

> The FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$.
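For concreteness, this factor of ~3 refers to a reverse-mode evaluation like the following (a generic sketch, not the pipeline's actual potential function):

```python
import jax
import jax.numpy as jnp

def logprob(q):
    # Stand-in for the model's log-probability.
    return -0.5 * jnp.sum(q**2)

q = jnp.ones(3)
# Each leapfrog step needs the log-probability and its gradient;
# reverse mode returns both for ~3x the cost of logprob alone.
value, grad = jax.value_and_grad(logprob)(q)
```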

Moreover, if the model contains deterministic variables, the model may have to be replayed once per sample to evaluate these variables from the sampled values. In NumPyro, samples are postprocessed after sampling, and the replay_model flag is decided at that point.

Therefore, for a given NUTS run, the total number of model evaluations is of the order of 3*extra_fields['num_steps'].sum() if the model contains no deterministic variables, and (3*extra_fields['num_steps'] + 1).sum() otherwise.
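A sketch of this accounting (the model here is again a trivial placeholder):

```python
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():  # placeholder for the actual field-level model
    numpyro.sample("z", dist.Normal(0.0, 1.0))

mcmc = MCMC(NUTS(model, max_tree_depth=10), num_warmup=100, num_samples=500)
mcmc.run(random.PRNGKey(0), extra_fields=("num_steps",))

# Leapfrog steps taken at each sampling iteration.
num_steps = mcmc.get_extra_fields()["num_steps"]
# ~3 model evaluations per value_and_grad call, +1 replay per sample
# if the model contains deterministic variables.
n_evals = 3 * num_steps.sum()            # no deterministic variables
n_evals_det = (3 * num_steps + 1).sum()  # with deterministic variables
```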

Long run

*(figure: long NUTS run, mtd=10)*

Summary Table

### NUTS, mtd=10

using 1560 rows, 6 parameters; mean weight 1.0, tot weight 1560.0

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff  | r_hat |
|---------|-------|------|--------|-------|-------|--------|-------|
| Omega_c | 0.25  | 0.01 | 0.24   | 0.23  | 0.26  | 245.96 | 1.00  |
| b1      | 0.99  | 0.03 | 0.99   | 0.94  | 1.04  | 65.24  | 1.01  |
| b2      | -0.03 | 0.01 | -0.03  | -0.05 | -0.00 | 83.45  | 1.00  |
| bn2     | 0.37  | 0.29 | 0.39   | -0.12 | 0.84  | 180.83 | 1.00  |
| bs2     | 0.01  | 0.05 | 0.01   | -0.07 | 0.09  | 96.91  | 1.01  |
| sigma8  | 0.83  | 0.01 | 0.83   | 0.81  | 0.85  | 58.27  | 1.01  |
hsimonfroy commented 5 months ago

Posterior Analysis

Some comments on the posterior samples, here the 1560 samples obtained by NUTS (mtd=10).

*(figure: posterior power spectrum)*

*(figure: triangle plot of the mesh and parameter posteriors, NUTS mtd=10, 1560 samples)*

*(figure: inverse mass matrix variances, NUTS mtd=10)*

hsimonfroy commented 5 months ago

Reparametrization comparison

Comparison between the standardly and unstandardly parametrized models, both sampled with NUTS (max_tree_depth=10); see the parametrization sketch below.
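As an illustrative sketch of the two parametrizations (generic NumPyro models, not the actual pipeline code): the standard version samples a unit Gaussian and rescales it deterministically, while the unstandard version samples the prior directly.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def unstandard_model(prior_std, obs=None):
    # Sample the field directly under its (here diagonal) prior.
    field = numpyro.sample("field", dist.Normal(0.0, prior_std))
    numpyro.sample("obs", dist.Normal(field, 0.1), obs=obs)

def standard_model(prior_std, obs=None):
    # Sample a unit Gaussian, then rescale: the prior is perfectly
    # conditioned, but the likelihood geometry is rescaled accordingly.
    z = numpyro.sample("z", dist.Normal(0.0, jnp.ones_like(prior_std)))
    # Note: this deterministic site triggers the model replay
    # discussed in the benchmarking section above.
    field = numpyro.deterministic("field", prior_std * z)
    numpyro.sample("obs", dist.Normal(field, 0.1), obs=obs)
```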

*(figure: standard vs. unstandard parametrization, NUTS mtd=10, 2048 samples)*

*(figure: standard vs. unstandard parametrization, NUTS mtd=10, 8192 samples)*

Summary Tables

### NUTS, mtd=10, unstandard param

using 8192 rows, 6 parameters; mean weight 1.0, tot weight 8192.0 (8 chains grouped)
total n_evals: 8380416 (= (2**10 - 1) * 8192)

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff   | r_hat |
|---------|-------|------|--------|-------|-------|---------|-------|
| Omega_c | 0.25  | 0.01 | 0.25   | 0.23  | 0.26  | 1573.10 | 1.00  |
| b1      | 0.99  | 0.03 | 0.99   | 0.94  | 1.04  | 178.29  | 1.03  |
| b2      | -0.03 | 0.01 | -0.03  | -0.05 | -0.01 | 324.02  | 1.01  |
| bn2     | 0.37  | 0.29 | 0.37   | -0.11 | 0.85  | 723.79  | 1.01  |
| bs2     | 0.02  | 0.05 | 0.02   | -0.06 | 0.10  | 340.57  | 1.01  |
| sigma8  | 0.83  | 0.01 | 0.83   | 0.81  | 0.85  | 207.84  | 1.02  |

### NUTS, mtd=10, standard param

using 8192 rows, 6 parameters; mean weight 1.0, tot weight 8192.0 (8 chains grouped)
total n_evals: 8380416 (= (2**10 - 1) * 8192)

|         | mean  | std  | median | 5.0%  | 95.0% | n_eff   | r_hat |
|---------|-------|------|--------|-------|-------|---------|-------|
| Omega_c | 0.25  | 0.01 | 0.25   | 0.23  | 0.26  | 1285.74 | 1.00  |
| b1      | 0.99  | 0.03 | 0.99   | 0.95  | 1.04  | 213.41  | 1.02  |
| b2      | -0.03 | 0.01 | -0.03  | -0.05 | -0.01 | 399.69  | 1.01  |
| bn2     | 0.36  | 0.29 | 0.36   | -0.09 | 0.85  | 816.97  | 1.01  |
| bs2     | 0.02  | 0.05 | 0.02   | -0.07 | 0.10  | 378.71  | 1.01  |
| sigma8  | 0.83  | 0.01 | 0.83   | 0.81  | 0.85  | 245.57  | 1.02  |

Theory

This is consistent with theory. The likelihood acts as a filter on the prior information. The standard parametrization makes the prior better conditioned but, depending on the likelihood, this can leave the posterior either better or worse conditioned.

*(figure: conditioning of the standard vs. unstandard parametrizations)*

(example from code)
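A per-mode Gaussian toy calculation (an illustrative sketch, not the survey posterior) makes the dependence on the likelihood explicit. Take a prior $z_i \sim \mathcal N(0, s_i^2)$ and a likelihood $x_i \sim \mathcal N(z_i, n_i^2)$ for each mode $i$. Sampling $z$ directly gives posterior variances

$$\mathrm{Var}(z_i \mid x_i) = \left(\frac{1}{s_i^2} + \frac{1}{n_i^2}\right)^{-1},$$

while the standardized parametrization $z_i = s_i u_i$, $u_i \sim \mathcal N(0, 1)$, gives

$$\mathrm{Var}(u_i \mid x_i) = \left(1 + \frac{s_i^2}{n_i^2}\right)^{-1}.$$

Where the likelihood is weak ($n_i \gg s_i$), the standardized variances approach 1, so standardization conditions the posterior well; where it is strong ($n_i \ll s_i$), they scale as $n_i^2 / s_i^2$, so the posterior conditioning is set by how the likelihood scales align with the prior scales, and standardization can help or hurt.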

See Betancourt's blog for another point of view, though it does not show the difference between strongly aligned and strongly opposing likelihoods.