arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0
1.6k stars 402 forks source link

reloo() appears to ignore `k_thresh` argument #1820

Closed derekpowell closed 3 years ago

derekpowell commented 3 years ago

Describe the bug: The (experimental) reloo() function appears to ignore the k_thresh argument. For a LOO object with 4 Pareto-k values > 0.70, of which 2 are > 1.0, setting k_thresh=1.0 does not reduce the number of times the model is refit to 2, as expected.

To Reproduce I encountered this while re-creating the numpyro refitting vignette. I've cut/pasted a reprex at the end of this report.

Expected behavior: Setting k_thresh=1.0 should prevent refitting for observations with Pareto-k values < 1.0.

Additional context: Python 3.8.8, arviz version = '0.11.2', numpyro version = '0.6.0'

import arviz as az
import numpyro
import numpyro.distributions as dist
import jax.random as random
from numpyro.infer import MCMC, NUTS
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import xarray as xr

# Ask NumPyro for 4 host devices, matching the 4 parallel chains configured below.
numpyro.set_host_device_count(4)

# Seed NumPy's global generator so the simulated dataset is reproducible.
np.random.seed(26)

# Intercept, slope and noise scale used to simulate the responses.
b0 = -2
b1 = 1
sigma = 3

# 100 evenly spaced predictor values on [0, 50] and their noisy linear responses.
xdata = np.linspace(start=0, stop=50, num=100)
ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)

def model(N, x, y=None):
    """Linear regression of ``y`` on ``x`` with Normal observation noise.

    Sample sites: ``b0`` (intercept) and ``b1`` (slope), both with a
    ``Normal(0, 10)`` prior, ``sigma_e`` with a ``HalfNormal(10)`` prior, and
    the observed site ``y``.

    Parameters
    ----------
    N
        Number of observations. Accepted for interface compatibility with the
        data dictionaries used in this script; not read by the body.
    x
        Predictor values.
    y
        Observed responses passed as ``obs``; defaults to ``None``.
    """
    intercept = numpyro.sample("b0", dist.Normal(0, 10))
    slope = numpyro.sample("b1", dist.Normal(0, 10))
    noise_scale = numpyro.sample("sigma_e", dist.HalfNormal(10))
    mean = intercept + slope * x
    numpyro.sample("y", dist.Normal(mean, noise_scale), obs=y)

# Arguments forwarded to ``model`` when running the sampler.
data_dict = dict(N=len(ydata), y=ydata, x=xdata)

# Sampler configuration; kept in a dict so the refitting wrapper can reuse it.
kernel = NUTS(model)
sample_kwargs = {
    "sampler": kernel,
    "num_warmup": 1000,
    "num_samples": 1000,
    "num_chains": 4,
    "chain_method": "parallel",
}
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

# Both "y" and "x" are labelled with a single "time" dimension.
dims = {name: ["time"] for name in ("y", "x")}
idata_kwargs = dict(dims=dims, constant_data={"x": xdata})
idata = az.from_numpyro(mcmc, **idata_kwargs)

class NumPyroSamplingWrapper(az.SamplingWrapper):
    """``az.SamplingWrapper`` implementation that refits a NumPyro model.

    Provides ``sample``, ``get_inference_data`` and ``log_likelihood__i``.
    ``sel_observations`` is not defined here and must come from a subclass.
    """

    def __init__(self, model, **kwargs):
        """Store the model function and RNG key, then defer to the base class.

        Parameters
        ----------
        model
            Object exposing ``model.sampler.model`` (the script passes a
            fitted ``MCMC`` instance).
        **kwargs
            ``rng_key`` (optional, defaults to ``random.PRNGKey(0)``) is
            consumed here; everything else is forwarded to
            ``az.SamplingWrapper.__init__``.
        """
        self.model_fun = model.sampler.model
        self.rng_key = kwargs.pop("rng_key", random.PRNGKey(0))

        super().__init__(model, **kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Evaluate the log likelihood of the excluded data under a refit.

        Parameters
        ----------
        excluded_obs
            Mapping of keyword arguments for the model function describing
            the held-out observations.
        idata__i
            Inference data of the refit; only its ``posterior`` group is read.

        Returns
        -------
        The log likelihood of the single observed site, reshaped to start
        with ``(chain, draw)`` dimensions.

        Raises
        ------
        ValueError
            If the model reports more than one likelihood term.
        """
        # Collapse the two leading dimensions (assumed to be chain and draw)
        # into a single sample axis for numpyro.infer.log_likelihood.
        samples = {
            key: values.values.reshape((-1, *values.values.shape[2:]))
            for key, values
            in idata__i.posterior.items()
        }
        log_likelihood_dict = numpyro.infer.log_likelihood(
            self.model_fun, samples, **excluded_obs
        )
        if len(log_likelihood_dict) > 1:
            raise ValueError("multiple likelihoods found")
        data = {}
        # ``sizes`` is the dimension-name -> length mapping; unlike
        # ``Dataset.dims[...]`` it is not slated to change its return type.
        nchains = idata__i.posterior.sizes["chain"]
        ndraws = idata__i.posterior.sizes["draw"]
        for obs_name, log_like in log_likelihood_dict.items():
            # Restore the (chain, draw) leading dimensions.
            shape = (nchains, ndraws) + log_like.shape[1:]
            data[obs_name] = np.reshape(log_like.copy(), shape)
        return az.dict_to_dataset(data)[obs_name]

    def sample(self, modified_observed_data):
        """Refit the model on ``modified_observed_data`` and return the ``MCMC``.

        A fresh subkey is split off ``self.rng_key`` on every call so that
        successive refits do not reuse the same random key.
        """
        self.rng_key, subkey = random.split(self.rng_key)
        mcmc = MCMC(**self.sample_kwargs)
        mcmc.run(subkey, **modified_observed_data)
        return mcmc

    def get_inference_data(self, fit):
        """Convert the refit ``fit`` into inference data.

        Uses the ``fit`` argument, not the module-level ``mcmc``; converting
        the global would hand back the original posterior for every refit.
        """
        # Cloned from PyStanSamplingWrapper.
        idata = az.from_numpyro(fit, **self.idata_kwargs)
        return idata

class LinRegWrapper(NumPyroSamplingWrapper):
    """Sampling wrapper for the linear regression data of this script."""

    def sel_observations(self, idx):
        """Split the original data into kept and held-out observations.

        Parameters
        ----------
        idx
            Positional index (or indices) of the observations to hold out.

        Returns
        -------
        tuple of dict
            ``(data__i, data_ex)``: model arguments (``x``, ``y``, ``N``) for
            the data without ``idx`` and for the held-out data, respectively.
        """
        x_all = self.idata_orig.constant_data["x"].values
        y_all = self.idata_orig.observed_data["y"].values
        held_out = np.isin(np.arange(len(x_all)), idx)

        def subset(selector):
            # Build the model arguments for the rows picked by ``selector``.
            y_sel = y_all[selector]
            return {"x": x_all[selector], "y": y_sel, "N": len(y_sel)}

        return subset(~held_out), subset(held_out)

# Pointwise LOO results for the original fit.
loo_orig = az.loo(idata, pointwise=True)

# Overwrite four Pareto k values: all four exceed 0.7, only two exceed 1.0.
loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])

wrapper_kwargs = dict(
    rng_key=random.PRNGKey(5),
    idata_orig=idata,
    sample_kwargs=sample_kwargs,
    idata_kwargs=idata_kwargs,
)
numpyro_wrapper = LinRegWrapper(mcmc, **wrapper_kwargs)

# model refits 4 times, expected 2 times
loo_relooed = az.reloo(numpyro_wrapper, loo_orig=loo_orig, k_thresh=1.0)
derekpowell commented 3 years ago

Just saw #1580 and reinstalled the dev version, which resolved the problem!