kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

Question regarding the data for the log-likelihood #149

Closed Qazalbash closed 7 months ago

Qazalbash commented 8 months ago

Description

Hi, I am trying to run flowMC to sample from the inhomogeneous Poisson likelihood mentioned in equation 4 of the paper. I am using the following Monte-Carlo approximation for the integral.

[Image: photo of the Monte-Carlo approximation for the integral]
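Roughly, for event $i$ with $N_i$ posterior samples $\lambda_{ij}$, the estimate has the form

$$\int \mathcal{L}(d_i \mid \lambda)\, p(\lambda \mid \Lambda)\, \mathrm{d}\lambda \;\approx\; \frac{1}{N_i} \sum_{j=1}^{N_i} p(\lambda_{ij} \mid \Lambda).$$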

You can see that we have a different number of samples for each event. JAX does not support ragged arrays, where each sub-array has a different size, so my next natural choice was jax.tree_map; however, you have vectorized the log_pdf (which in my case is the log-likelihood) with vmap, and that does not work with differently-sized arrays.

Is there any other way to pass those pre-computed values?

I am attaching the code below.

from __future__ import annotations

import jax
from jax import jit
from jax import numpy as jnp
from numpyro import distributions as dist

from gwkokab.vts.utils import interpolate_hdf5

from ..models import *
from ..utils.misc import get_key

@jit
def exp_rate(rate, *, pop_params) -> float:
    """Monte-Carlo estimate of the expected number of events."""
    N = 1 << 13  # 2 ** 13 = 8192 samples
    lambdas = Wysocki2019MassModel(
        alpha_m=pop_params["alpha_m"],
        k=0,
        mmin=pop_params["mmin"],
        mmax=pop_params["mmax"],
    ).sample(get_key(), sample_shape=(N,))
    m1 = lambdas[..., 0]
    m2 = lambdas[..., 1]
    # Average exp(interpolate_hdf5) over the population samples.
    value = jnp.exp(interpolate_hdf5(m1, m2))
    return rate * jnp.sum(value) / N

def log_inhomogeneous_poisson_likelihood(x, data=None):
    # Unpack the hyperparameters: x = [alpha, mmin, mmax, rate].
    alpha = x[..., 0]
    mmin = x[..., 1]
    mmax = x[..., 2]
    rate = x[..., 3]
    alpha_prior = dist.LogUniform(-5.0, 5.0)
    mmin_prior = dist.LogUniform(5.0, 30.0)
    mmax_prior = dist.LogUniform(30.0, 100.0)
    # rate_prior = dist.Uniform(1, 500)
    # Expected number of events for the current hyperparameters.
    expval = exp_rate(
        rate,
        pop_params={
            "alpha_m": alpha,
            "mmin": mmin,
            "mmax": mmax,
        },
    )
    mass_model = Wysocki2019MassModel(
        alpha_m=alpha,
        k=0,
        mmin=mmin,
        mmax=mmax,
    )
    # Per-event Monte-Carlo estimate of the integral; data is a pytree of
    # per-event sample arrays, each with its own length.
    log_integral = jnp.sum(
        jnp.asarray(
            jax.tree_map(
                lambda d: jax.nn.logsumexp(
                    mass_model.log_prob(d)
                    - alpha_prior.log_prob(alpha)
                    - mmin_prior.log_prob(mmin)
                    - mmax_prior.log_prob(mmax),
                )
                - jnp.log(len(d))
                + jnp.log(rate),
                data,
            )
        )
    )
    return log_integral - expval

Wysocki2019MassModel implements equation 7 of the same paper mentioned above. Also, I have a nagging feeling that something is wrong with the likelihood function itself, but I cannot pin it down.

kazewong commented 8 months ago

@Qazalbash Regarding the sub-arrays being different sizes, you mean the number of samples per event, d, differing from event to event, right? Usually, the easiest way is to upsample or downsample the events to a common size; then one should be able to vmap over them without problems.
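Concretely, something like this (a rough sketch; the helper name and data layout are illustrative, assuming each event is an (N_i, 2) array of samples):

```python
import jax
import jax.numpy as jnp

def resample_to_common_size(key, event_samples, n_common=1024):
    # Draw n_common samples (with replacement) from each event's array so
    # that all events share one shape and can be stacked for vmap.
    keys = jax.random.split(key, len(event_samples))
    resampled = [
        jax.random.choice(k, d, shape=(n_common,), replace=True, axis=0)
        for k, d in zip(keys, event_samples)
    ]
    return jnp.stack(resampled)  # (n_events, n_common, 2)
```

Once stacked, vmapping over the leading event axis works as usual.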

Regarding your likelihood, a number of things:

  1. It seems your exp_rate and likelihood function are constructing the Wysocki2019MassModel at every call, which I am not sure is intended. Depending on how Wysocki2019MassModel is written, this can massively slow down the computation. The same goes for the priors. I would initialize those objects outside the likelihood and pass them in as data.
  2. The function signature of the likelihood should be f(x: array, data: dict), where x is a 1D array of the parameters, meaning it should be something like x = jnp.array([alpha, mmin, mmax, rate]). See the sketch after this list.
  3. The sample call in the mass model looks like a stochastic process to me, and I am not sure you need it. I assume lambdas are the hyperparameters and are there to compute the rate. In that case, x in the likelihood is your lambda, so you don't have to resample it within the likelihood.
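Putting points 1 and 2 together, here is a minimal sketch of the shape I have in mind (the dict keys and the toy power-law term are placeholders, not your actual model):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_likelihood(x: jnp.ndarray, data: dict) -> jnp.ndarray:
    # x is the 1D parameter vector: x = jnp.array([alpha, mmin, mmax, rate]).
    alpha, rate = x[0], x[3]
    samples = data["samples"]       # (n_events, n_common, 2), precomputed once
    m1 = samples[..., 0]
    log_pop = -alpha * jnp.log(m1)  # toy stand-in for your model's log_prob
    n_common = samples.shape[1]
    # Per-event Monte-Carlo log-integral, then summed over events.
    per_event = logsumexp(log_pop, axis=1) - jnp.log(n_common) + jnp.log(rate)
    return jnp.sum(per_event)
```

Anything that does not change between calls (priors, interpolants, the stacked samples) goes into data once, outside the sampler loop.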
Qazalbash commented 8 months ago

@kazewong Thank you for the points. I have a few uncertainties about certain aspects of the process, which I've outlined below. I hope this doesn't inconvenience you.

  1. The reason for constructing Wysocki2019MassModel inside log_inhomogeneous_poisson_likelihood is that, as I understand it, we are running flowMC to recover the parameters $\alpha$, $m_{\text{min}}$, $m_{\text{max}}$, and $\mathcal{R}$. These values change at every iteration, since they are passed in as x, so the model should be rebuilt to reflect them (see the toy sketch after this list). This is my understanding, and I am not confident in it.
  2. I am following one of the tutorials in the docs, and there data is passed as a jnp.array. I also cannot see how to pass different data to the sampler in the form of a dict.
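To make my confusion in point 1 concrete, here is a toy sketch of the distinction as I understand it (illustrative names only):

```python
from jax import numpy as jnp
from numpyro import distributions as dist

# Fixed hyperprior: its parameters never change, so it can be built once, outside.
alpha_prior = dist.LogUniform(-5.0, 5.0)

def toy_log_likelihood(x, data):
    # The population model depends on the current draw x, so I rebuild it here.
    mass_model = Wysocki2019MassModel(alpha_m=x[0], k=0, mmin=x[1], mmax=x[2])
    return jnp.sum(mass_model.log_prob(data["samples"])) + alpha_prior.log_prob(x[0])
```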

This is the Wysocki2019MassModel.

```python
# Copyright 2023 The GWKokab Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing_extensions import Optional

from jax import lax, numpy as jnp
from jax.random import uniform
from jaxtyping import Array
from numpyro.distributions import constraints, Distribution
from numpyro.distributions.util import promote_shapes, validate_sample

from ..utils.misc import get_key


class Wysocki2019MassModel(Distribution):
    r"""It is a double-sided truncated power law distribution, as described in
    equation 7 of the `paper `__.

    .. math::
        p(m_1,m_2\mid\alpha,k,m_{\text{min}},m_{\text{max}},M_{\text{max}})\propto
        \frac{m_1^{-\alpha-k}m_2^k}{m_1-m_{\text{min}}}
    """

    arg_constraints = {
        "alpha_m": constraints.real,
        "k": constraints.nonnegative_integer,
        "mmin": constraints.positive,
        "mmax": constraints.positive,
    }
    support = constraints.real_vector
    reparametrized_params = ["m1", "m2"]

    def __init__(self, alpha_m: float, k: int, mmin: float, mmax: float, *, valid_args=None) -> None:
        r"""Initialize the power law distribution with a lower and upper mass limit.

        :param alpha_m: index of the power law distribution
        :param k: mass ratio power law index
        :param mmin: lower mass limit
        :param mmax: upper mass limit
        :param valid_args: If `True`, validate the input arguments.
        """
        self.alpha_m, self.k, self.mmin, self.mmax = promote_shapes(alpha_m, k, mmin, mmax)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha_m),
            jnp.shape(k),
            jnp.shape(mmin),
            jnp.shape(mmax),
        )
        super(Wysocki2019MassModel, self).__init__(
            batch_shape=batch_shape,
            validate_args=valid_args,
            event_shape=(2,),
        )

    @validate_sample
    def log_prob(self, value):
        return (
            -(self.alpha_m + self.k) * jnp.log(value[..., 0])
            + self.k * jnp.log(value[..., 1])
            - jnp.log(value[..., 0] - self.mmin)
        )

    def sample(self, key: Optional[Array | int], sample_shape: tuple = ()) -> Array:
        if key is None or isinstance(key, int):
            key = get_key(key)
        m2 = uniform(key=key, minval=self.mmin, maxval=self.mmax, shape=sample_shape + self.batch_shape)
        U = uniform(key=get_key(key), minval=0.0, maxval=1.0, shape=sample_shape + self.batch_shape)
        beta = 1 - (self.k + self.alpha_m)
        conditions = [beta == 0.0, beta != 0.0]
        choices = [
            jnp.exp(U * jnp.log(self.mmax) + (1.0 - U) * jnp.log(m2)),
            jnp.exp(jnp.power(beta, -1.0) * jnp.log(U * jnp.power(self.mmax, beta) + (1.0 - U) * jnp.power(m2, beta))),
        ]
        m1 = jnp.select(conditions, choices)
        return jnp.stack([m1, m2], axis=-1)

    def __repr__(self) -> str:
        string = f"Wysocki2019MassModel(alpha_m={self.alpha_m}, k={self.k}, "
        string += f"mmin={self.mmin}, mmax={self.mmax})"
        return string
```
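For reference, a quick usage sketch of the class (my own example, with arbitrary parameter values):

```python
import jax

model = Wysocki2019MassModel(alpha_m=0.8, k=0, mmin=5.0, mmax=40.0)
samples = model.sample(jax.random.PRNGKey(0), sample_shape=(100,))  # (100, 2)
log_p = model.log_prob(samples)                                     # (100,)
```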

I have pretty much got an idea of where I am going wrong; your response will make it clearer. Thanks in advance.

kazewong commented 7 months ago

I think this is resolved. Closing the issue.