pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.12k stars 232 forks source link

Low effective sample size #1527

Closed MaAl13 closed 1 year ago

MaAl13 commented 1 year ago

Hello everybody, I am relatively new to Bayesian inference and wanted to try out a hierarchical model. Basically, what I want to do is estimate parameters for individual patients, similar to the example in the documentation. The forward problem is defined by an ODE. The code below runs, but the estimates are completely off, which is probably related to the high R-hat values and low effective sample sizes (ESS). Can anybody tell me how to fix this? Any tips on speeding it up would also be much appreciated!

import dill
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from diffrax import ConstantStepSize, Dopri5, ODETerm, SaveAt, diffeqsolve
from jax import random
from numpyro.infer import MCMC, NUTS  # , Predictive
from sklearn.preprocessing import LabelEncoder

# from fem_cycle_model.fetch_params import fetch_params
# from fem_cycle_model.main import run_model
# Silence pandas' SettingWithCopyWarning globally for this script.
pd.options.mode.chained_assignment = None  # default='warn'
# Guard: this script was written against numpyro 0.11.0.
# NOTE(review): `assert` is stripped under `python -O`; raise explicitly if this check matters.
assert numpyro.__version__.startswith("0.11.0")

# Select the number of cores that numpyro will use
# Expose 4 host devices so the 4 MCMC chains configured further down can run in parallel.
# NOTE(review): this only takes effect if called before JAX initialises its backend — keep it near the top.
numpyro.set_host_device_count(4)

# Patient IDs run from A01 to A26 and from B01 to B24
# function that merges dataframes
def merge_dataframes(dataframes):
    """Concatenate an iterable of DataFrames row-wise into a single DataFrame.

    Parameters
    ----------
    dataframes : iterable of pandas.DataFrame
        Frames to stack, in order. Their original row indices are kept as-is.

    Returns
    -------
    pandas.DataFrame
        The concatenated frame, or an empty DataFrame when ``dataframes`` is empty.
    """
    # Materialise once so a generator can be checked for emptiness and then consumed.
    frames = list(dataframes)
    if not frames:
        # pd.concat raises on an empty list; the previous loop returned an empty frame.
        return pd.DataFrame()
    # A single concat is linear in the total row count; concatenating inside a loop
    # re-copies the accumulated frame on every iteration (quadratic).
    return pd.concat(frames)

# insert a column in a dataframe
def insert_column(df, column_name, column_values):
    """Prepend ``column_name`` as the first column of ``df`` and return ``df``.

    The frame is modified in place; the same object is returned for chaining.
    ``column_values`` may be a scalar (broadcast to every row) or a sequence
    with one value per row.
    """
    first_position = 0
    df.insert(loc=first_position, column=column_name, value=column_values)
    return df

# make a list containing the same string n times
def make_array(string, n):
    """Return a list holding ``string`` repeated ``n`` times.

    A non-positive ``n`` yields an empty list.
    """
    # Sequence repetition builds the list directly; no Python-level append loop needed.
    return [string] * n

# Create the patient IDs A01 to A26 and B01 to B24 for cycles 1 and 2
def string_id_generator(n_a=26, n_b=24, n_cycles=2):
    """Build patient/cycle IDs of the form ``<group><nn>_<cycle>``.

    For each cycle ``1..n_cycles``, the IDs ``A01..A<n_a>`` are listed first,
    followed by ``B01..B<n_b>``. Patient numbers are zero-padded to at least two
    digits. With the defaults this yields 100 IDs, from ``"A01_1"`` to ``"B24_2"``.

    Parameters
    ----------
    n_a, n_b : int
        Number of patients in groups A and B.
    n_cycles : int
        Number of cycles; cycle numbers start at 1.

    Returns
    -------
    list of str
    """
    ids = []  # renamed from `list`, which shadowed the builtin
    for cycle in range(1, n_cycles + 1):
        # Order within a cycle matters to callers slicing by position: all A IDs, then all B IDs.
        for prefix, count in (("A", n_a), ("B", n_b)):
            ids.extend(f"{prefix}{num:02d}_{cycle}" for num in range(1, count + 1))
    return ids

def string_desplitter(code):
    """Split an ID such as ``"A01_2"`` into ``("A01", 2)``.

    Only the first two ``_``-separated fields are used; the second must parse
    as an integer.
    """
    fields = code.split("_")
    patient, cycle = fields[0], fields[1]
    return patient, int(cycle)

def test_ode(t, z, theta):
    """
    Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describes the interaction of two species.
    """
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = (
        theta[0],
        theta[1],
        theta[2],
        theta[3],
    )
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])

# The solution is saved at 50 evenly spaced time points from 0 to 10 (spacing 10/49 ≈ 0.204)

def sim(params, patient_id=None):
    """Integrate ``test_ode`` on ``t`` in ``[0, 10]`` from ``y0 = (1, 0.2)`` with rates ``params``.

    Uses the Dopri5 solver with a constant step of 0.1 and records the state at
    50 evenly spaced times (``np.linspace(0, 10, 50)``).

    Returns
    -------
    If ``patient_id`` is None: a tuple ``(dom, non_dom)`` with the first and
    second state components at the saved times.
    Otherwise: a DataFrame with columns ``patient_id``, ``dom`` and ``non_dom``,
    one row per saved time.
    """
    solution = diffeqsolve(
        ODETerm(test_ode),
        Dopri5(),
        t0=0,
        t1=10,
        dt0=0.1,
        y0=jnp.array([1, 0.2]),
        saveat=SaveAt(ts=np.linspace(0, 10, 50)),
        stepsize_controller=ConstantStepSize(),
        args=params,
    )
    states = solution.ys
    if patient_id is None:
        return states[:, 0], states[:, 1]
    trajectories = pd.DataFrame(states).rename(columns={0: "dom", 1: "non_dom"})
    return insert_column(trajectories, "patient_id", patient_id)

patient_IDs = string_id_generator()
# This experiment only uses the first three IDs: A01_1, A02_1 and A03_1.
n_test_patients = 3
patient_IDs_test = patient_IDs[:n_test_patients]

def run_model_data(patient_IDs, params):
    """Simulate one trajectory per patient and stack them into a single DataFrame.

    The patient at 0-based position ``i`` in ``patient_IDs`` is simulated with the
    base rates ``params`` scaled by ``i + 1``, so each synthetic patient has
    distinct parameters.

    Parameters
    ----------
    patient_IDs : sequence of str
        IDs written to the ``patient_id`` column, in simulation order.
    params : sequence of numbers
        Base rates ``(alpha, beta, gamma, delta)`` passed on to ``sim``.

    Returns
    -------
    pandas.DataFrame
        Columns ``patient_id``, ``dom`` and ``non_dom``. The per-patient row indices
        are kept, so index values repeat across patients. Empty if ``patient_IDs`` is.
    """
    base = jnp.array(params)  # convert once rather than once per patient
    frames = [
        sim(base * (i + 1), patient_id)
        for i, patient_id in enumerate(patient_IDs)
    ]
    if not frames:
        return pd.DataFrame()
    # One concat instead of re-concatenating the growing frame on every iteration.
    return pd.concat(frames)

# Synthetic observations: patient i is simulated with the base rates [2, 3, 4, 5] scaled by i + 1.
patient_data = run_model_data(patient_IDs_test, params=[2, 3, 4, 5])

def run_model_all(params):
    """Simulate every patient and return the concatenated trajectories.

    Parameters
    ----------
    params : sequence of 1-D arrays
        ``params[j][i]`` is rate ``j`` of patient ``i``. All arrays have the same
        length, which is the number of patients.

    Returns
    -------
    tuple of jax arrays
        ``(dom, non_dom)``, each 1-D and patient-major: the saved time points of
        patient 0 come first, then those of patient 1, and so on.
    """
    import jax  # local import: the module only brings `jax.random` into scope

    # Shape (n_patients, n_params): row i holds the rates of patient i.
    per_patient = jnp.stack([jnp.asarray(p) for p in params], axis=-1)
    if per_patient.shape[0] == 0:
        empty = jnp.array([])
        return empty, empty
    # A single batched ODE solve replaces the old per-patient Python loop. That loop
    # unrolled into a separate solve per patient at trace time (slow compilation) and
    # rebuilt the outputs with jnp.append, which copies quadratically.
    dom, non_dom = jax.vmap(sim)(per_patient)
    return dom.reshape(-1), non_dom.reshape(-1)

# Map each patient-ID string to an integer code 0..N_PATIENTS-1 (fit returns the encoder itself).
encoder = LabelEncoder().fit(patient_IDs_test)
N_PATIENTS = encoder.classes_.size

def model(patient_code, dom_foll=None, non_dom_foll=None):
    """Hierarchical Lotka–Volterra model with per-patient rates.

    Each rate (alpha, beta, gamma, delta) has a population mean ``μ_foll_*`` with a
    Uniform prior and a spread ``σ_foll_*`` drawn from a left-truncated Normal.
    Each patient draws its own rate from ``Normal(μ, σ)``. The ODE is then solved
    once per entry of ``patient_code``, and both state components are compared
    with the observations under a shared Gaussian noise scale ``σ``.

    Parameters
    ----------
    patient_code : int array
        Integer patient index (as produced by ``encoder.transform``) for each
        trajectory to simulate.
    dom_foll, non_dom_foll : array or None
        Observed trajectories, concatenated patient-major in the same order as
        ``patient_code``. When ``None``, the observation sites are sampled rather
        than conditioned on.
    """
    # Population-level priors: a mean and a spread for each of the four rates.
    # NOTE(review): LeftTruncatedDistribution uses low=0.0 by default, so every σ_foll_*
    # (and σ below) is restricted to positive values — confirm this is intended.
    μ_foll_alpha = numpyro.sample("μ_foll_alpha", dist.Uniform(1.0, 7.0))
    σ_foll_alpha = numpyro.sample("σ_foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(3.0, 2.0)))

    μ_foll_beta = numpyro.sample("μ_foll_beta", dist.Uniform(2.0, 10.0))
    σ_foll_beta = numpyro.sample("σ_foll_beta", dist.LeftTruncatedDistribution(dist.Normal(4.0, 2.0)))

    μ_foll_gamma = numpyro.sample("μ_foll_gamma", dist.Uniform(3.0, 13.0))
    σ_foll_gamma = numpyro.sample("σ_foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(5.0, 2.0)))

    μ_foll_delta = numpyro.sample("μ_foll_delta", dist.Uniform(4.0, 16.0))
    σ_foll_delta = numpyro.sample("σ_foll_delta", dist.LeftTruncatedDistribution(dist.Normal(7.0, 2.0)))

    # Observation noise scale, shared by both state components.
    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))

    # One draw per patient for each rate.
    # NOTE(review): this is the centered parameterisation (rates drawn directly from
    # Normal(μ, σ)). If NUTS reports divergences, a non-centered version (e.g. via
    # numpyro.infer.reparam.LocScaleReparam) is the usual remedy. Normal draws can
    # also be negative, although the Lotka–Volterra rates are meant to be positive.
    with numpyro.plate("plate_i", N_PATIENTS):
        foll_alpha = numpyro.sample("foll_alpha", dist.Normal(μ_foll_alpha, σ_foll_alpha))
        foll_beta = numpyro.sample("foll_beta", dist.Normal(μ_foll_beta, σ_foll_beta))
        foll_gamma = numpyro.sample("foll_gamma", dist.Normal(μ_foll_gamma, σ_foll_gamma))
        foll_delta = numpyro.sample("foll_delta", dist.Normal(μ_foll_delta, σ_foll_delta))

    # Index the per-patient rates by patient_code so each simulated trajectory uses
    # its own patient's parameters.
    foll_dom_est, foll_non_dom_est = run_model_all(
        [
            foll_alpha[patient_code],
            foll_beta[patient_code],
            foll_gamma[patient_code],
            foll_delta[patient_code],
        ]
    )
    # Gaussian likelihood for both state components, with the shared noise scale σ.
    numpyro.sample("obs_dom", dist.Normal(foll_dom_est, σ), obs=dom_foll)
    numpyro.sample("obs_non_dom", dist.Normal(foll_non_dom_est, σ), obs=non_dom_foll)

# Observed trajectories, concatenated patient-major as produced by run_model_data.
data_dict = dict(
    dom_foll=jnp.array(patient_data.dom),
    non_dom_foll=jnp.array(patient_data.non_dom),
)

# Transform the patient IDs into integer codes, which index the per-patient parameters in the model
patient_code = jnp.array(encoder.transform(patient_IDs_test))

# Add the patient codes to the data dictionary
data_dict.update({"patient_code": patient_code})

# Specify the number of chains in the Markov Chain Monte Carlo. Typically set to the number of
# devices made available with numpyro.set_host_device_count
mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)

# Use a fixed PRNG key so that every run gives the same, reproducible result. For more see:
# https://ericmjl.github.io/dl-workshop/02-jax-idioms/03-deterministic-randomness.html
rng_key = random.PRNGKey(12)
# Only seed1 is used below. The split into 5 keys is kept so that seed1 (and therefore the
# results) stay the same.
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)

inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample()), **mcmc_kwargs)
inference_mcmc.run(seed1, **data_dict)

# This block lets the posterior be pickled: it clears the sampler's cached functions and
# the MCMC compile cache.
# NOTE(review): these are private attributes and may change between numpyro versions;
# pickling `inference_mcmc.get_samples()` instead is less fragile.
inference_mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._init_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._potential_fn_gen = None  # pylint: disable=protected-access
inference_mcmc._cache = {}  # pylint: disable=protected-access

# Saving the posterior
with open("savemcmc.pkl", "wb") as f:
    dill.dump(inference_mcmc, f)

# print_summary() prints the table itself and returns None. Wrapping it in print() only
# added a stray "None" line.
inference_mcmc.print_summary()

The summary is the following:

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
 foll_alpha[0]    1.51      0.67      1.71      0.28      2.27      2.45      2.27
 foll_alpha[1]    1.82      1.21      1.37      0.51      3.88      2.21      3.15
 foll_alpha[2]    1.90      0.78      1.73      0.65      3.19      3.05      1.86
  foll_beta[0]    2.43      0.92      2.58      0.68      3.57      3.02      1.69
  foll_beta[1]    3.62      1.69      3.24      1.35      6.16      2.60      2.03
  foll_beta[2]    4.55      1.98      4.11      1.53      7.75      3.30      1.68
 foll_delta[0]    6.57      2.44      5.74      3.78     10.42      3.59      1.51
 foll_delta[1]   13.43      3.39     12.96      8.34     18.96     13.30      1.11
 foll_delta[2]   14.33      3.72     14.27      7.72     20.07     15.41      1.17
 foll_gamma[0]    5.40      2.29      4.59      2.86      8.96      3.52      1.53
 foll_gamma[1]   10.36      2.82      9.95      6.02     15.04     11.74      1.12
 foll_gamma[2]   12.35      3.17     12.26      6.85     17.35     14.76      1.16
  μ_foll_alpha    2.13      1.01      1.85      1.00      3.45     10.88      1.11
   μ_foll_beta    4.06      1.62      3.65      2.00      6.41     12.86      1.10
  μ_foll_delta   10.79      2.95     11.01      6.57     15.95    397.68      1.01
  μ_foll_gamma    8.83      2.34      9.01      5.43     12.83     76.82      1.03
             σ    0.46      0.02      0.46      0.43      0.49     51.93      1.04
  σ_foll_alpha    1.65      1.32      1.29      0.01      3.52      7.84      1.16
   σ_foll_beta    2.78      1.70      2.62      0.02      5.10     10.99      1.11
  σ_foll_delta    6.64      1.81      6.60      3.72      9.64   1975.99      1.01
  σ_foll_gamma    4.93      1.64      4.83      2.17      7.46    917.31      1.02

Number of divergences: 36

fehiepsi commented 1 year ago

@MaAl13 Could you post the question in our forum instead? https://forum.pyro.ai/ We only use github to track issues and features.