pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

`pm.sampling_jax` doesn't sample properly with `pm.Censored` #5897

Open kylejcaron opened 2 years ago

kylejcaron commented 2 years ago

Description of your problem

When using pm.sampling_jax with pm.Censored, the sampled posterior contains the same value for all draws within each chain. For instance, with 4 chains and 1000 draws per chain, I'd see output that looks like the following (each column is one chain, and every draw within it is identical):

import numpy as np

example_output = np.c_[
  np.ones(1000)*0.25,
  np.ones(1000)*0.1,
  np.ones(1000)*0.5,
  np.ones(1000)*0.35
]

print(example_output)

Please provide a minimal, self-contained, and reproducible example.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify
from aesara.scalar import Log1mexp

# implement jax Op for Log1mexp

@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
    def log1mexp(x):
        return jnp.where(
            x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
        )

    return log1mexp
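
# (Illustrative aside, not part of the original report: a quick numerical check
# that the registered JAX function matches log(1 - exp(x)); op/node are unused
# by the returned closure, so None placeholders are enough here.)
_log1mexp_jax = jax_funcify_Log1mexp(None, None)
_xs = np.array([-1e-8, -0.1, -1.0, -10.0])
assert np.allclose(np.asarray(_log1mexp_jax(jnp.asarray(_xs))), np.log(-np.expm1(_xs)), rtol=1e-4)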

np.random.seed(99)
# simulate 100 contexts/units with 100 observations each
contexts = 100
obs_per_context = 100
idxs = np.repeat(range(contexts), obs_per_context)

k_true = np.random.lognormal(0.45, 0.25, contexts)
lambd_true = np.random.lognormal(4.25, 0.5, contexts)

dist = pm.Weibull.dist(k_true[idxs], lambd_true[idxs])
Et = dist.eval()

# Simulate event time data
df_ = pd.DataFrame({
    "group":idxs,
    "event_time":Et
})

# Randomly censor observations
censor_time = np.random.uniform(0,250, size=len(df_))
df = (
    df_
    .assign(censored = lambda df: np.where(df.event_time > censor_time, 1, 0))
    .assign(event_time = lambda df: np.where(df.event_time > censor_time, censor_time, df.event_time) )
)
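
# (Illustrative aside, not part of the original report: check how much of the
# simulated data actually ends up censored before fitting.)
print(f"censored fraction: {df.censored.mean():.2f}")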

# Fit model
coords = {"group":df.group.unique()}

with pm.Model(coords=coords) as mW:
    g_ = pm.MutableData("g_", df.group.values)
    y = pm.MutableData("y", df.event_time.values)
    c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.NaN) )

    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")
    log_lambd = pm.Normal("log_lambd", 4.5, 0.5, dims="group")

    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    y_ = pm.Censored("event", y_latent, lower=None, upper=c_, observed=y)
#     # not using pm.Censored samples fine with pm.sampling_jax
#     y_ = pm.Weibull("event", k[g_], lambd[g_], observed=y)

with mW:
    idata = pm.sampling_jax.sample_numpyro_nuts()
#     idata = pm.sample(init="adapt_diag") # works normally

# returns 4 - one value for each chain
print(len(np.unique(idata.posterior["log_k"].to_numpy()[:, :, 0])))

# see warning
az.plot_trace(idata, var_names=["log_k"]);
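
# (Illustrative aside, not part of the original report: a more direct view of the
# pathology is to count unique draws per chain; each chain should have ~1000
# distinct values, but here every chain collapses to a single repeated value.)
log_k0 = idata.posterior["log_k"].to_numpy()[:, :, 0]
for chain_idx, draws in enumerate(log_k0):
    print(f"chain {chain_idx}: {len(np.unique(draws))} unique draws")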

Please provide the full traceback.

Not applicable: no error is raised in this case, but as far as I can tell the functionality is definitely not working as intended.


ricardoV94 commented 2 years ago

Hmm, can you try to change this line?

c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.inf))

Otherwise those NaNs in the graph will cause havoc. What happens when you use pm.sample, doesn't it also complain?

kylejcaron commented 2 years ago

> Hmm, can you try to change this line?
>
>     c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.inf))
>
> Otherwise those NaNs in the graph will cause havoc. What happens when you use pm.sample, doesn't it also complain?

I just tested np.inf and it didn't solve the issue, and using pm.sample works fine with the original example and doesn't complain.

kylejcaron commented 2 years ago

What's really weird is that if I add a hierarchical global mean to one of the parameters, it alleviates some of the behavior (but the sampling still looks pretty off).

For instance, if I change this line from my first example:

    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")

to

    mu_log_k = pm.Normal("mu_log_k", 0.5, 0.1)
    log_k = pm.Normal("log_k", mu_log_k, 0.5, dims="group")

then the sampling improves, but it still seems problematic: ArviZ raises the following user warning when plotting a trace, and it looks like some chains are still affected.

/Users/kcaron/.pyenv/versions/3.9.7/envs/pymc4_env/lib/python3.9/site-packages/arviz/stats/density_utils.py:491: UserWarning: Your data appears to have a single value or no finite values
  warnings.warn("Your data appears to have a single value or no finite values")

[Screenshot: trace plot of log_k, 2022-06-14]
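
(As an illustrative aside, not in the original comment: a sketch of how the remaining affected chains could be counted, assuming idata comes from a run of the modified model above.)

log_k0 = idata.posterior["log_k"].to_numpy()[:, :, 0]
stuck = sum(len(np.unique(chain_draws)) == 1 for chain_draws in log_k0)
print(f"{stuck} of {log_k0.shape[0]} chains have only one unique draw")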

ricardoV94 commented 2 years ago

It sounds like your model is very ill-conditioned and NumPyro just does a slightly worse job than PyMC.

~Are you sure that the np.random.lognormal prior corresponds to the PyMC one?~ Seems to be correct.

Also, how many divergences are you getting with pm.sample?

kylejcaron commented 2 years ago

> It sounds like your model is very ill-conditioned and NumPyro just does a slightly worse job than PyMC.
>
> ~Are you sure that the np.random.lognormal prior corresponds to the PyMC one?~ Seems to be correct.
>
> Also, how many divergences are you getting with pm.sample?

The model seems fine from what I know: there are no divergences, effective sample sizes are high, and r_hat is essentially 1.0 across the board (I found 3 of 400 parameters with r_hat = 1.01).
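
(For reference, a minimal sketch of how those diagnostics could be checked with ArviZ, assuming idata comes from a pm.sample run of the model above.)

# divergence count and summary diagnostics for the pm.sample result
print("divergences:", int(idata.sample_stats["diverging"].sum()))
summary = az.summary(idata, var_names=["log_k", "log_lambd"])
print(summary[["ess_bulk", "ess_tail", "r_hat"]].describe())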

ricardoV94 commented 2 years ago

In that case my best guess is some numerical instabilities/errors in the JAX backend.