Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/
Other

Request for Special Priors #91

Closed Gaurav17Joshi closed 1 year ago

Gaurav17Joshi commented 1 year ago

Is your feature request related to a problem? Please describe.
Hi, I am working on a project with the stingray library in which we are building a feature to analyze astronomical time series using Gaussian processes, with jaxns for Bayesian inference and evidence estimation.

For this, we want to use some multivariate priors. We want to build the mean function as a sum of multiple component functions (e.g. Gaussians or exponentials).

[Screenshot: equation defining the mean function as a sum of n Gaussians]

Here n is the number of Gaussians, and $A_i$, $t_i$, $\sigma_i$ are the parameters of the i-th Gaussian.
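For reference, assuming the Gaussian form that the likelihood code later in this thread evaluates (with components indexed from 0, as in $t_0, \dots, t_{n-1}$), the mean function would read:

$$
m(t) = \sum_{i=0}^{n-1} A_i \exp\left(-\frac{(t - t_i)^2}{2\sigma_i^2}\right)
$$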

We want to build two kinds of multivariate distributions for these mean parameters for our jaxns prior_model.

  1. Independent Multivariate Uniform Prior: For the parameters A and sigma, we have upper and lower bounds, and we want to sample n values of each uniformly.
  2. Constrained Multivariate Beta Prior: For the parameters $t_i$, we need a constrained prior that samples times such that $t_i < t_{i+1}$. E.g. for a 0 to 20 s light curve with n = 3, we have to sample the three peak times such that $0 < t_0 < t_1 < t_2 < 20$. We have used ForcedIdentifiability from jaxns and it works well, but in the reference paper the authors use a constrained Beta distribution: $t_0$ follows a Beta(alpha = 1, beta = n = 3) distribution from 0 to 20 s, $t_1$ follows a Beta(1, n - 1 = 2) distribution from $t_0$ to 20 s, and $t_2$ follows a Beta(1, n - 2 = 1) distribution from $t_1$ to 20 s. That is, it is a conditional Beta where alpha stays at 1 and beta decreases from n to 1. We want to implement that as well.
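Nested samplers such as jaxns draw points in the unit hypercube and map each coordinate to the prior through its inverse CDF. Under that view, item 1 is just an affine rescaling per component. A minimal pure-JAX sketch (the helper `uniform_transform` is hypothetical, not a jaxns API, and the bounds are illustrative):

```python
import jax.numpy as jnp
from jax import random


def uniform_transform(u, low, high):
    """Map unit-cube coordinates u to independent Uniform(low, high) draws.

    The inverse CDF of Uniform(low, high) is affine, so each component is
    rescaled independently; low and high broadcast against u.
    """
    return low + (high - low) * u


n = 3
u_A, u_sigma = random.uniform(random.PRNGKey(0), (2, n))
A = uniform_transform(u_A, 0.0, 1.0)          # n amplitudes in [0, 1)
sigma = uniform_transform(u_sigma, 0.1, 2.0)  # n widths in [0.1, 2.0); illustrative bounds
print(A, sigma)
```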

Describe the solution you'd like
It would be very helpful if these two special priors could be included in the special priors section of jaxns. I would be glad to help implement these features if you can share some references on how special priors are built.

Describe alternatives you've considered
I tried to build these as tfpd joint distributions, but that did not work in jaxns because tfpd joint distributions do not have a _quantile method.
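As a side note on the quantile requirement: the conditional Beta(1, b) pieces in item 2 have a closed-form inverse CDF. The CDF of Beta(1, b) on [0, 1] is $F(x) = 1 - (1 - x)^b$, so its quantile is $F^{-1}(u) = 1 - (1 - u)^{1/b}$. A small pure-JAX round-trip check (the function names are hypothetical helpers, not part of jaxns or TFP):

```python
import jax.numpy as jnp


def beta1_cdf(x, b):
    """CDF of Beta(1, b) on [0, 1]: F(x) = 1 - (1 - x)**b."""
    return 1.0 - (1.0 - x) ** b


def beta1_quantile(u, b):
    """Inverse CDF of Beta(1, b): F^{-1}(u) = 1 - (1 - u)**(1 / b)."""
    return 1.0 - (1.0 - u) ** (1.0 / b)


u = jnp.linspace(0.0, 1.0, 11)
for b in (1.0, 2.0, 3.0):
    x = beta1_quantile(u, b)
    # Round trip: F(F^{-1}(u)) should recover u up to float rounding
    print(b, float(jnp.max(jnp.abs(beta1_cdf(x, b) - u))))
```

For b = 1 the quantile is the identity, i.e. Beta(1, 1) is Uniform(0, 1).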

Joshuaalbert commented 1 year ago

Hi @Gaurav17Joshi, let me see if I understand correctly. You want to use a multivariate uniform distribution for A_i and sigma_i, and for t_i you want to use a scaled and shifted Beta(1, n - i)? Note that t_i generated this way are not forced to be sorted, and thus are not identifiable, so you will likely have degeneracies in your posterior. Also, the mean of Beta(1, n - i) is 1/(1 + n - i), which increases with i only on average; individual samples can still come out in any order.
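A quick arithmetic check of those unscaled means, using $E[\mathrm{Beta}(\alpha, \beta)] = \alpha / (\alpha + \beta)$ with $\alpha = 1$ and $\beta = n - i$, for n = 3:

```python
from fractions import Fraction


def beta_mean(alpha, beta):
    """Mean of a Beta(alpha, beta) random variable, alpha / (alpha + beta)."""
    return Fraction(alpha, alpha + beta)


n = 3
means = [beta_mean(1, n - i) for i in range(n)]
print(means)  # [Fraction(1, 4), Fraction(1, 3), Fraction(1, 2)]
```

These are the means before shifting and scaling to (t_min, t_max).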

You can do all of this without special priors. Note the casting of constants with jnp.asarray(..., float_type); this is good practice if you'll be using 64-bit JAX, which I assume you are with GPs.
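To illustrate why the cast matters, a minimal sketch assuming 64-bit mode is enabled and using jnp.float64 as a stand-in for jaxns' float_type: an integer expression such as n - i becomes an integer array unless it is cast explicitly.

```python
import jax

jax.config.update("jax_enable_x64", True)  # enable 64-bit mode before creating any arrays

import jax.numpy as jnp

float_type = jnp.float64  # stand-in for jaxns' float_type under 64-bit JAX (assumption)
n, i = 3, 1

uncast = jnp.asarray(n - i)            # integer dtype (int64 in 64-bit mode)
cast = jnp.asarray(n - i, float_type)  # float64
print(uncast.dtype, cast.dtype)
```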

I'll assume some data for illustration.

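import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit JAX before importing jaxns

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import random, vmap
from typing import Any, Generator

from jaxns import Model, Prior
from jaxns.types import float_type  # path as used in this thread; may differ in other jaxns versions

tfpd = tfp.distributions
PriorModelGen = Generator[Any, Any, Any]  # local stand-in for jaxns' own PriorModelGen type alias

# Your hyperparameters for priors; when using 64-bit it's important to cast things appropriately
n = 3
t_max = jnp.asarray(20., float_type)
t_min = jnp.asarray(0., float_type)
A_lower = jnp.asarray(0., float_type)
A_upper = jnp.asarray(1., float_type)
sigma_lower = jnp.asarray(0., float_type)
sigma_upper = jnp.asarray(1., float_type)

# Make fake data
X = jnp.linspace(t_min, t_max, 5)
Y = jnp.exp(jnp.sin(X))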

Then, define the prior model:

def prior_model() -> PriorModelGen:
    A = yield Prior(tfpd.Uniform(low=A_lower * jnp.ones(n), high=A_upper * jnp.ones(n)), name='A')
    sigma = yield Prior(tfpd.Uniform(low=sigma_lower * jnp.ones(n), high=sigma_upper * jnp.ones(n)), name='sigma')
    t_array = []
    scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
    shift_bij = tfp.bijectors.Shift(shift=t_min)
    for i in range(n):
        underlying_beta = tfpd.Beta(
            concentration1=jnp.asarray(1., float_type),
            concentration0=jnp.asarray(n - i, float_type)
        )
        t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
        t_array.append(t)
    t_array = jnp.stack(t_array)
    return A, sigma, t_array
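To see the identifiability issue concretely, here is a pure-JAX sketch (not jaxns) that draws the unconditioned, identically scaled Beta(1, n - i) times of this prior model via the closed-form inverse CDF of Beta(1, b), $1 - (1 - u)^{1/b}$, and measures how often they come out sorted. The seed and sample size are arbitrary.

```python
import jax.numpy as jnp
from jax import random

n, num_samples = 3, 100_000
t_min, t_max = 0.0, 20.0

u = random.uniform(random.PRNGKey(0), (num_samples, n))
b = jnp.arange(n, 0, -1)  # concentration0 = n - i for i = 0..n-1, i.e. [3, 2, 1]
# Each column i is an independent Beta(1, n - i) draw, all scaled to (t_min, t_max)
t = t_min + (t_max - t_min) * (1.0 - (1.0 - u) ** (1.0 / b))

frac_sorted = float(jnp.mean(jnp.all(jnp.diff(t, axis=1) > 0, axis=1)))
print("fraction of sorted draws:", frac_sorted)
print("column means:", t.mean(axis=0))  # close to (t_max - t_min) / (1 + n - i)
```

Only a fraction of the draws are ordered, which is the source of the posterior degeneracy described above.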

Finally, define the likelihood and test the model:

def log_likelihood(A, sigma, t_array):
    @vmap
    def eval_mean(x):
        dx = (t_array - x) / sigma
        components = A * jnp.exp(-0.5 * jnp.square(dx))
        return jnp.sum(components)

    m_X = eval_mean(X)
    # Do something with m_X
    return -jnp.sum(jnp.square(Y - m_X))

model = Model(prior_model=prior_model,
              log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)

# Example prior sample
print(model.transform(model.sample_U(random.PRNGKey(42))))
# {'A': Array([0.26283673, 0.10945365, 0.46926031], dtype=float64), 'sigma': Array([0.40959993, 0.08672034, 0.48140902], dtype=float64), 't0': Array(15.24431331, dtype=float64), 't1': Array(19.5986375, dtype=float64), 't2': Array(0.76363376, dtype=float64)}
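As a design note, the vmap over evaluation points is equivalent to broadcasting the evaluation points against the n components. A minimal sketch comparing both forms on arbitrary values (the function names are illustrative only):

```python
import jax
import jax.numpy as jnp


def mean_vmap(x, A, sigma, t_array):
    """Sum-of-Gaussians mean, mapping over evaluation points with vmap."""
    @jax.vmap
    def eval_mean(xi):
        dx = (t_array - xi) / sigma
        return jnp.sum(A * jnp.exp(-0.5 * jnp.square(dx)))
    return eval_mean(x)


def mean_broadcast(x, A, sigma, t_array):
    """Same mean via broadcasting: x[:, None] is [N, 1], parameters are [n], result [N, n]."""
    dx = (x[:, None] - t_array) / sigma
    return jnp.sum(A * jnp.exp(-0.5 * jnp.square(dx)), axis=-1)


x = jnp.linspace(0.0, 20.0, 5)
A = jnp.array([0.3, 0.1, 0.5])
sigma = jnp.array([0.4, 1.0, 2.0])
t_array = jnp.array([2.0, 9.0, 15.0])
print(mean_vmap(x, A, sigma, t_array))
print(mean_broadcast(x, A, sigma, t_array))
```

Both evaluate $\sum_i A_i \exp(-(x - t_i)^2 / (2\sigma_i^2))$ at each x; the broadcast form materializes an [N, n] array, which is cheap for small n.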

Does this meet your needs?

Gaurav17Joshi commented 1 year ago

Hi @Joshuaalbert, thanks for the prompt reply. Your code resolves the first request, for A and sigma; it fits my needs and is working in my code.

As for the second one, the Constrained Multivariate Beta Prior is indeed shifted and scaled, but it is also conditional. The shift and scale are not the same for every time: for $t_i$ with $i > 0$, the shift is $t_{i-1}$ and the scale is $t_{max} - t_{i-1}$, instead of all times being shifted and scaled by the same amounts. This ensures that successive times are in increasing order, $t_{min} < t_0 < t_1 < \dots < t_{n-1} < t_{max}$. I have written a tfpd joint distribution for n = 3 to make this clearer:

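import tensorflow_probability.substrates.jax as tfp
tfd, tfb = tfp.distributions, tfp.bijectors
t_min, t_max = 0., 20.

jointds = tfd.JointDistributionSequential([
    tfb.Shift(t_min)(tfb.Scale(t_max - t_min)(tfd.Beta(1., 3.))),          # t_0 in (t_min, t_max)
    lambda t_0: tfb.Shift(t_0)(tfb.Scale(t_max - t_0)(tfd.Beta(1., 2.))),  # t_1 in (t_0, t_max)
    lambda t_1: tfb.Shift(t_1)(tfb.Scale(t_max - t_1)(tfd.Beta(1., 1.))),  # t_2 in (t_1, t_max)
])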
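The joint distribution above can also be written directly as a map from unit-cube coordinates, which is the form a nested sampler works with. A minimal pure-JAX sketch, assuming the closed-form Beta(1, b) inverse CDF $1 - (1 - u)^{1/b}$; the names `beta1_quantile` and `ordered_times` are hypothetical helpers, not jaxns APIs:

```python
import jax.numpy as jnp


def beta1_quantile(u, b):
    """Inverse CDF of Beta(1, b): 1 - (1 - u)**(1 / b)."""
    return 1.0 - (1.0 - u) ** (1.0 / b)


def ordered_times(u, t_min, t_max):
    """Map u in [0, 1]^n to t_min <= t_0 <= ... <= t_{n-1} <= t_max.

    t_0 = t_min + (t_max - t_min) * Beta(1, n), and each later
    t_i = t_{i-1} + (t_max - t_{i-1}) * Beta(1, n - i).
    """
    n = u.shape[0]
    lower = t_min
    times = []
    for i in range(n):
        t = lower + (t_max - lower) * beta1_quantile(u[i], n - i)
        times.append(t)
        lower = t  # the next time is drawn between this one and t_max
    return jnp.stack(times)


print(ordered_times(jnp.array([0.5, 0.5, 0.5]), 0.0, 20.0))  # approx. [4.126, 8.775, 14.388]
```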

Avoiding degeneracy in the times is central to this project. As I said, the ForcedIdentifiability special prior does an excellent job of providing non-degenerate samples satisfying $t_{min} < t_0 < t_1 < \dots < t_{n-1} < t_{max}$; it just gives us uniform samples. I need a similar prior that samples through the conditional Beta distributions instead.
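One property worth checking here (a standard order-statistics fact, not something stated in this thread): with alpha = 1 and beta = n - i, the conditional Beta recursion is the usual sequential construction of the sorted values of n independent uniforms, because the minimum of k uniforms on (a, b) is $a + (b - a)\,\mathrm{Beta}(1, k)$ and, given that minimum, the remaining ones are again uniform above it. So its samples should match sorted uniform draws. A quick Monte Carlo sketch in pure JAX, with arbitrary seed and sample size:

```python
import jax.numpy as jnp
from jax import random


def ordered_times_batch(u, t_min, t_max):
    """Conditional Beta recursion applied row-wise to u with shape [..., n]."""
    n = u.shape[-1]
    lower = jnp.full(u.shape[:-1], t_min)
    columns = []
    for i in range(n):
        # Inverse CDF of Beta(1, n - i) is 1 - (1 - u)**(1 / (n - i))
        lower = lower + (t_max - lower) * (1.0 - (1.0 - u[..., i]) ** (1.0 / (n - i)))
        columns.append(lower)
    return jnp.stack(columns, axis=-1)


n, num_samples = 3, 200_000
key_beta, key_sort = random.split(random.PRNGKey(1))
beta_draws = ordered_times_batch(random.uniform(key_beta, (num_samples, n)), 0.0, 1.0)
sorted_uniform = jnp.sort(random.uniform(key_sort, (num_samples, n)), axis=-1)

# For sorted uniforms on (0, 1), E[t_k] = (k + 1) / (n + 1), i.e. [0.25, 0.5, 0.75] here
print(beta_draws.mean(axis=0), sorted_uniform.mean(axis=0))
print(beta_draws.std(axis=0), sorted_uniform.std(axis=0))
```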

An image of the constrained Beta PDF:

[Screenshot: plot of the constrained Beta PDF]
Joshuaalbert commented 1 year ago

Ah, I see. I misread your original post, which is why I pointed out the degeneracy. But it's very easy to incorporate into the code if n is reasonably small:

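def prior_model() -> PriorModelGen:
    A = yield Prior(tfpd.Uniform(low=A_lower * jnp.ones(n), high=A_upper * jnp.ones(n)), name='A')
    sigma = yield Prior(tfpd.Uniform(low=sigma_lower * jnp.ones(n), high=sigma_upper * jnp.ones(n)), name='sigma')
    t_array = []
    scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
    shift_bij = tfp.bijectors.Shift(shift=t_min)
    for i in range(n):
        underlying_beta = tfpd.Beta(
            concentration1=jnp.asarray(1., float_type),
            concentration0=jnp.asarray(n - i, float_type)
        )
        t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
        # Condition the next time on this one: t_{i+1} is drawn from (t, t_max)
        scale_bij = tfp.bijectors.Scale(scale=t_max - t)
        shift_bij = tfp.bijectors.Shift(shift=t)
        t_array.append(t)
    t_array = jnp.stack(t_array)
    return A, sigma, t_array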
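On "if n is reasonably small": the Python loop is unrolled when traced, so the traced program grows with n. For the pure transform, outside jaxns' generator-based prior_model, the same recursion can be written with jax.lax.scan so the step is traced once. A sketch under that assumption; wiring it into jaxns (for example by yielding a single Uniform(0, 1) vector prior and applying this transform afterwards) is my assumption and not something confirmed in this thread:

```python
import jax
import jax.numpy as jnp


def ordered_times_scan(u, t_min, t_max):
    """Conditional Beta(1, n - i) recursion over u with shape [n], via lax.scan."""
    n = u.shape[0]
    b = jnp.arange(n, 0, -1).astype(u.dtype)  # Beta concentration0 values [n, n-1, ..., 1]

    def step(lower, inputs):
        u_i, b_i = inputs
        # Inverse CDF of Beta(1, b_i), shifted to `lower` and scaled by (t_max - lower)
        t_i = lower + (t_max - lower) * (1.0 - (1.0 - u_i) ** (1.0 / b_i))
        return t_i, t_i  # new carry (lower bound for the next time) and the output t_i

    init = jnp.asarray(t_min, u.dtype)
    _, times = jax.lax.scan(step, init, (u, b))
    return times


u = jnp.array([0.5, 0.5, 0.5])
print(ordered_times_scan(u, 0.0, 20.0))           # approx. [4.126, 8.775, 14.388]
print(jax.jit(ordered_times_scan)(u, 0.0, 20.0))  # same values under jit
```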
Gaurav17Joshi commented 1 year ago

Thanks, this is working well in my code. Just one more request: do you have any ideas on how to write tests for such priors, and for the prior_model in general?
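As a starting point, a few property checks one might write for such a prior, sketched as plain pytest-style functions against a pure-JAX version of the transform. The transform and test names are hypothetical; testing through jaxns itself (e.g. via model.transform, as used earlier in this thread) would follow the same pattern. The checks: samples stay within bounds, come out ordered, and the first time matches its Beta(1, n) marginal, whose CDF on (0, 1) is $1 - (1 - x)^n$.

```python
import jax.numpy as jnp
from jax import random

N_TIMES, T_MIN, T_MAX = 3, 0.0, 20.0


def ordered_times_batch(u, t_min=T_MIN, t_max=T_MAX):
    """Conditional Beta(1, n - i) recursion applied row-wise to u of shape [..., n]."""
    n = u.shape[-1]
    lower = jnp.full(u.shape[:-1], t_min)
    cols = []
    for i in range(n):
        lower = lower + (t_max - lower) * (1.0 - (1.0 - u[..., i]) ** (1.0 / (n - i)))
        cols.append(lower)
    return jnp.stack(cols, axis=-1)


def _samples(seed=0, num=50_000):
    return ordered_times_batch(random.uniform(random.PRNGKey(seed), (num, N_TIMES)))


def test_within_bounds():
    t = _samples()
    assert bool(jnp.all((t >= T_MIN) & (t <= T_MAX)))


def test_ordered():
    # Non-strict check: float rounding can, very rarely, make neighbours equal
    t = _samples()
    assert bool(jnp.all(jnp.diff(t, axis=-1) >= 0))


def test_first_time_marginal():
    # Compare the empirical CDF of t_0 with the Beta(1, n) CDF on a small grid
    x = (_samples()[:, 0] - T_MIN) / (T_MAX - T_MIN)
    grid = jnp.linspace(0.05, 0.95, 10)
    empirical = jnp.mean(x[:, None] <= grid[None, :], axis=0)
    exact = 1.0 - (1.0 - grid) ** N_TIMES
    assert float(jnp.max(jnp.abs(empirical - exact))) < 0.01
```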

Joshuaalbert commented 1 year ago

What aspect are you looking to test?

Joshuaalbert commented 1 year ago

@Gaurav17Joshi I would love to help you further with testing priors, but can you please open a thread in the discussions? I'll close this issue, given that we've addressed the original request.

https://github.com/Joshuaalbert/jaxns/discussions