Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Differential information gain biased with phantom samples #123

Open Joshuaalbert opened 7 months ago

Joshuaalbert commented 7 months ago

Describe the bug

-H = KL(L(x)p(x)/Z || p(x)) = int (L(x)p(x)/Z) log(L(x)p(x) / (Z p(x))) dx = E[log(L)] - log(Z)

-H = jnp.sum(dp_mean * log_L_samples) - log_Z_mean
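As a sanity check, here is a minimal, self-contained sketch of that estimator. The variable names (`log_L_samples`, `log_dp_mean`, `log_Z_mean`) follow the line above, but the numeric values are toy placeholders, not JAXNS output:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Toy per-sample log-likelihoods and log prior-volume shrinkage weights
# (hypothetical values standing in for a real nested sampling run).
log_L_samples = jnp.array([-3.0, -2.0, -1.5, -1.0])
log_w = jnp.array([-1.0, -1.0, -1.5, -2.0])

# log Z = logsumexp(log L + log w)
log_Z_mean = logsumexp(log_L_samples + log_w)

# Normalised posterior mass per sample: dp_i = L_i * w_i / Z.
log_dp_mean = log_L_samples + log_w - log_Z_mean
dp_mean = jnp.exp(log_dp_mean)

# -H = E_posterior[log L] - log Z
neg_H = jnp.sum(dp_mean * log_L_samples) - log_Z_mean
```

Since `dp_mean` sums to one, `neg_H` is bracketed by `min(log_L) - log_Z` and `max(log_L) - log_Z`, which makes the reported blow-up with many phantom samples easy to spot.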

But I'm seeing -H come out too large when more phantom points are retained. It should be comparable to the amount of compression needed to reach the typical set. Similarly, I saw a k+1 scaling factor in ESS and logZ_uncert. It might be that log_dp_mean should be divided by k+1 to solve both problems. The discrepancy likely comes down to the actual number of effective samples versus how shrinkage is computed. Needs to be investigated.
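One way to probe the k+1 scaling independently of the shrinkage computation is to compare against the standard Kish effective sample size computed directly from the normalised posterior weights. This sketch assumes `log_dp` holds per-sample log posterior masses (toy values here, not JAXNS output):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical per-sample log posterior masses (unnormalised).
log_dp = jnp.array([-2.0, -1.5, -1.2, -1.0, -3.0])
log_dp = log_dp - logsumexp(log_dp)  # normalise so weights sum to 1

# Kish ESS = (sum w)^2 / sum w^2, computed in log space for stability.
log_ess = 2.0 * logsumexp(log_dp) - logsumexp(2.0 * log_dp)
ess = jnp.exp(log_ess)
```

If the ESS reported by the sampler differs from this direct estimate by roughly a factor of k+1 when k phantom samples are retained per acceptance, that would support dividing log_dp_mean by k+1.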

JAXNS version 2.3.x, 2.4.0