kazewong / jim

Gravitational-wave data analysis tools in Jax
MIT License

Prior uniform in m1-m2 space with a bound in chirp mass and mass ratio #164

Closed tsunhopang closed 1 month ago

tsunhopang commented 1 month ago

This PR defines the prior to be uniform in the component masses, while the bounds are specified in chirp mass and mass ratio. Using the bound-to-unbound transform in chirp-mass and mass-ratio space together with the current initialization procedure would solve the sampling-side problem, but the evidence estimate would still be off because of a constant shift in the log posterior. This PR fixes that issue.
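A minimal sketch of the construction, using the bounds from the test script below: m_1 is sampled directly and m_2 enters through a unit quantile that is mapped onto the m_2 interval allowed by the chirp-mass and mass-ratio bounds. The helper names here are illustrative only and not the jimgw API; only Mc_m1_to_m2 is real (the script imports it too, and it is assumed here to accept scalars the way the script calls it on arrays), and the exact bookkeeping inside UniformInComponentMassSecondaryMassTransform may differ.

import numpy as np
from jimgw.single_event.utils import Mc_m1_to_m2

def m2_bounds(m_1, q_min=0.125, q_max=1.0, M_c_min=5.0, M_c_max=15.0):
    # intersect the mass-ratio bound q_min <= m_2 / m_1 <= q_max with the
    # chirp-mass bound M_c_min <= M_c(m_1, m_2) <= M_c_max; this assumes the
    # resulting interval is non-empty for every m_1 in the prior range, as it
    # is for the numbers used here
    lo = max(q_min * m_1, Mc_m1_to_m2(M_c_min, m_1)[0].real)
    hi = min(q_max * m_1, Mc_m1_to_m2(M_c_max, m_1)[0].real)
    return lo, hi

def quantile_to_m2(m_2_quantile, m_1):
    # map the unit quantile onto the allowed m_2 interval for this m_1
    lo, hi = m2_bounds(m_1)
    m_2 = lo + m_2_quantile * (hi - lo)
    # log(hi - lo) is the log-Jacobian of this map; it has to be accounted for
    # so that the density is uniform in (m_1, m_2) rather than in
    # (m_1, m_2_quantile), which is presumably what avoids the constant offset
    # in the log posterior described above
    return m_2, np.log(hi - lo)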

The following test script was used; the resulting scatter plot of the samples is shown below.

[test_m1_m2.png: scatter of the samples in the m_1-m_2 plane]

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior
from jimgw.single_event.likelihood import ZeroLikelihood
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import UniformInComponentMassSecondaryMassTransform
from jimgw.single_event.utils import Mc_m1_to_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

# m_1 is sampled directly; m_2 enters through a unit quantile that the
# likelihood transform below maps onto the allowed m_2 range
m1_prior = UniformPrior(6.0, 53.0, parameter_names=['m_1'])
m2_q_prior = UniformPrior(0.0, 1.0, parameter_names=['m_2_quantile'])

prior = CombinePrior(
    [
        m1_prior,
        m2_q_prior,
    ]
)

sample_transforms = [
    # bound-to-unbound transforms: map each bounded sampling parameter onto
    # the real line for the sampler
    BoundToUnbound(
        name_mapping = (["m_1"], ["m_1_unbounded"]),
        original_lower_bound=m1_prior.xmin, original_upper_bound=m1_prior.xmax
    ),
    BoundToUnbound(
        name_mapping = (["m_2_quantile"], ["m_2_quantile_unbounded"]),
        original_lower_bound=m2_q_prior.xmin, original_upper_bound=m2_q_prior.xmax
    ),
]

likelihood_transforms = [
    # map (m_1, m_2_quantile) to the physical (m_1, m_2), with m_2 restricted
    # to the range allowed by the chirp-mass and mass-ratio bounds below
    UniformInComponentMassSecondaryMassTransform(
        q_min=0.125, q_max=1.0,
        M_c_min=5.0, M_c_max=15.0,
        m_1_min=m1_prior.xmin,
        m_1_max=m1_prior.xmax
    ),
]

# zero log-likelihood: the samples follow the prior alone, so the m_1-m_2
# scatter directly checks the prior construction
likelihood = ZeroLikelihood()

# diagonal mass matrix and small step size for flowMC's local sampler
mass_matrix = jnp.eye(len(prior.base_prior))
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4

jim = Jim(
    likelihood,
    prior,
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms,
    n_loop_training=n_loop_training,
    n_loop_production=1,
    n_local_steps=5,
    n_global_steps=100,
    n_chains=100,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    n_max_examples=30,
    n_flow_samples=100,
    momentum=0.9,
    batch_size=100,
    use_global=True,
    train_thinning=1,
    output_thinning=1,
    local_sampler_arg=local_sampler_arg,
    strategies=["default"],
)

print("Start sampling")
key = jax.random.PRNGKey(42)
jim.sample(key)
jim.print_summary()
samples = jim.get_samples()

# map the samples from the sampling parameterization (m_1, m_2_quantile) to
# the physical component masses (m_1, m_2) for plotting
for transform in likelihood_transforms:
    samples = jax.vmap(transform.forward)(samples)

import matplotlib
matplotlib.use("agg")
matplotlib.rcParams.update(
    {'font.size': 16,
     'text.usetex': True,
     'font.family': 'Times New Roman'}
)
import matplotlib.pyplot as plt
plt.figure(1)
plt.xlim([5, 53])
plt.ylim([1.8, 18])
# draw the mass-ratio (q = 1, q = 0.125) and chirp-mass (M_c = 5, M_c = 15) boundary lines
import numpy as np
x = np.linspace(0., 100.)
plt.plot(x, 1 * x, color='k', linestyle='--')
plt.plot(x, 0.125 * x, color='k', linestyle='--')
plt.plot(x, Mc_m1_to_m2(5., x)[0].real, color='k', linestyle='--')
plt.plot(x, Mc_m1_to_m2(15., x)[0].real, color='k', linestyle='--')
plt.scatter(samples['m_1'], samples['m_2'])
plt.xlabel(r'$m_1 [M_\odot]$')
plt.ylabel(r'$m_2 [M_\odot]$')
plt.savefig('test_m1_m2.png', bbox_inches='tight')
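
If the prior behaves as intended, the scatter should stay within, and roughly uniformly fill, the region cut out by the dashed q = 1, q = 0.125, M_c = 5, and M_c = 15 boundaries, together with the m_1 prior bounds.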
tsunhopang commented 1 month ago

Added the transform to the Pv2 testing script and updated the Mc_m1_to_m2 function.
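
For reference, a standalone sketch (not the jimgw implementation) of the relation that Mc_m1_to_m2 inverts: at fixed m_1, the chirp mass M_c = (m_1 m_2)^(3/5) / (m_1 + m_2)^(1/5) gives m_2 as the positive real root of a cubic, which may also be why the test script above takes [0].real of the returned roots.

import numpy as np

def mc_of_m1_m2(m_1, m_2):
    # chirp mass from the component masses
    return (m_1 * m_2) ** 0.6 / (m_1 + m_2) ** 0.2

def m2_of_mc_m1(M_c, m_1):
    # m_2 solves m_1^3 m_2^3 - M_c^5 m_2 - M_c^5 m_1 = 0, obtained by raising
    # the chirp-mass definition to the fifth power
    roots = np.roots([m_1**3, 0.0, -M_c**5, -(M_c**5) * m_1])
    real = roots[np.abs(roots.imag) < 1e-8].real
    return real[real > 0][0]

# round-trip check: recover M_c = 8 from (m_1, m_2) = (20, m2_of_mc_m1(8, 20))
assert np.isclose(mc_of_m1_m2(20.0, m2_of_mc_m1(8.0, 20.0)), 8.0)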