arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

InferenceData coords for NumPyro plates #2022

Open ColdTeapot273K opened 2 years ago

ColdTeapot273K commented 2 years ago

Tell us about it

When creating InferenceData using

az.from_numpyro(...)

the resulting autogenerated coordinates are not very telling (screenshot omitted).

Now consider the NumPyro model below that produced these samples. These coords with autogenerated names are in fact plate dimensions. And plates have names.


import numpy as np
import pandas as pd

# INFO: PPL specific imports

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import arviz as az

# %%

data_uri = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/NWOGrants.csv"
df_dev = pd.read_csv(data_uri, sep=";")
df_dev.head()

df_dev["gender"] = df_dev["gender"] == "m"
df_dev["gender"] = df_dev["gender"].astype(int)
df_dev["discipline"] = df_dev["discipline"].astype("category").cat.codes

# %%

def model(data: pd.DataFrame, observed=True):
    applications = data["applications"].values
    awards = data["awards"].values

    discipline = data["discipline"].values
    discipline_card = np.unique(discipline).shape[0]
    gender = data["gender"].values
    gender_card = np.unique(gender).shape[0]

    observations_card = data.shape[0]

    # INFO: good plate version
    with numpyro.plate("plate_gender", gender_card):
        with numpyro.plate("plate_discipline", discipline_card):
            alpha_gender_discipline = numpyro.sample("alpha_gender_discipline", dist.Normal(-1, 1))

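    # The outer plate ("plate_gender") takes the rightmost batch dim, so the shape is (discipline, gender)
    assert alpha_gender_discipline.shape == (discipline_card, gender_card)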

    link_p = numpyro.deterministic("link_p", alpha_gender_discipline[discipline, gender])

    with numpyro.plate("plate_observations", observations_card):
        numpyro.sample(
            "awards", dist.Binomial(total_count=applications, logits=link_p), obs=awards if observed else None
        )

kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=5000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(random.PRNGKey(0), df_dev)
samples = mcmc.get_samples()

az.from_numpyro(mcmc)

Thoughts on implementation

It would be handy if from_numpyro could extract the plate names from the NumPyro model and use them as dimension names.

I'll note that providing custom coords like so az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}) produces no effect (coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).
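A minimal sketch of the rename workaround, assuming ArviZ's default `<var>_dim_<i>` naming for unlabeled axes. To stay self-contained it builds the `InferenceData` with `az.from_dict` from random stand-in draws shaped like this model's posterior (chain, draw, discipline, gender) instead of running MCMC:

```python
import numpy as np
import arviz as az

# Stand-in posterior: random draws, NOT real MCMC output.
# Shape is (chain, draw, discipline, gender), like alpha_gender_discipline.
rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior={"alpha_gender_discipline": rng.normal(size=(1, 100, 9, 2))}
)

# Unlabeled axes get default "<var>_dim_<i>" names; rename them after the fact.
# InferenceData.rename renames both the dimension and its coordinate.
idata = idata.rename(
    {
        "alpha_gender_discipline_dim_0": "discipline",
        "alpha_gender_discipline_dim_1": "gender",
    }
)
print(idata.posterior["alpha_gender_discipline"].dims)
```

With the model above, gender is the outer plate and therefore the rightmost axis, so under this naming it would be `alpha_gender_discipline_dim_1` rather than `_dim_0`.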

OriolAbril commented 2 years ago

That would be nice. Do you know how the dimensions are stored in the numpyro model or mcmc object?

If it is possible to get a dictionary with variable names as keys and a list of the dimensions/plates as values, it would be straightforward to implement. The PyMC converter does that already.

I'll note that providing custom coords like so az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}) produces no effect (coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).

This sounds like a confusion between dimensions and coordinates. Maybe this blog post I wrote can help clarify the difference (it uses CmdStanPy).
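To make the distinction concrete: `dims` maps each variable to the names of its axes (excluding `chain` and `draw`), while `coords` maps each dimension name to its labels. As far as I can tell, a `coords` entry is only used once some variable has a dimension with that name, which is consistent with the "no effect" observation in the issue. A sketch, again using `az.from_dict` with random stand-in draws rather than real MCMC output:

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
draws = rng.normal(size=(1, 100, 9, 2))  # stand-in: (chain, draw, discipline, gender)

# dims: variable name -> names of its non-sample dimensions
dims = {"alpha_gender_discipline": ["discipline", "gender"]}
# coords: dimension name -> labels along that dimension
coords = {"discipline": np.arange(9), "gender": np.array([0, 1])}

idata = az.from_dict(
    posterior={"alpha_gender_discipline": draws},
    coords=coords,
    dims=dims,
)
print(idata.posterior["alpha_gender_discipline"].dims)
print(idata.posterior["gender"].values)
```

With the `mcmc` object from the issue, the same dictionaries can be passed as `az.from_numpyro(mcmc, coords=coords, dims=dims)`; the feature requested here would amount to deriving `dims` from the plate names automatically.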

kylejcaron commented 5 months ago


Leaving some notes here. I might be interested in picking this up later if I have time; I implemented something similar in my own code recently and figured I'd share it.

In NumPyro, dims are stored in the model (as opposed to the mcmc object) and tend to be captured by plates, which typically represent independent draws across a dimension.

def model(X, y=None):
    ...
    with numpyro.plate("categories", n_categories):
        alpha = numpyro.sample("alpha", dist.Normal(0, 5))
    ...

They can be pulled out of the model as follows:

from numpyro.infer.inspect import get_model_relations

plates = get_model_relations(
    model,
    model_args=model_args,
    model_kwargs=model_kwargs,
)["plate_sample"]

# I haven't really checked this; it's probably wrong when there are multiple dims, but it's a start
dims = {
    value: [key]
    for key, lst in plates.items()
    for value in lst
}
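As the comment says, the comprehension above keeps only one plate per site: a variable inside two plates ends up with whichever plate comes last. A minimal sketch of a fix accumulates every plate per site instead. `plate_sample` does not appear to record each plate's position (`dim`), so the left-to-right order is passed in explicitly; `plates_to_dims`, `order`, and the example mapping (modelled on the issue's model) are hypothetical:

```python
from collections import defaultdict


def plates_to_dims(plate_sample, plate_order):
    """Invert {plate_name: [site, ...]} into {site: [plate_name, ...]}.

    plate_order lists plate names from the leftmost to the rightmost batch
    dimension; it is supplied by the caller (hypothetical helper).
    """
    rank = {name: i for i, name in enumerate(plate_order)}
    dims = defaultdict(list)
    for plate, sites in plate_sample.items():
        for site in sites:
            dims[site].append(plate)  # accumulate instead of overwriting
    return {site: sorted(plates, key=rank.__getitem__) for site, plates in dims.items()}


# Hypothetical plate_sample mapping for the issue's model
plate_sample = {
    "plate_gender": ["alpha_gender_discipline"],
    "plate_discipline": ["alpha_gender_discipline"],
    "plate_observations": ["awards"],
}
order = ["plate_discipline", "plate_gender", "plate_observations"]
print(plates_to_dims(plate_sample, order))
# {'alpha_gender_discipline': ['plate_discipline', 'plate_gender'], 'awards': ['plate_observations']}
```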

Notice that get_model_relations also requires model_args and model_kwargs: the model's inputs have to be supplied for the inspect tool to work. from_numpyro would therefore need the model as an optional input AND the model inputs, which makes for a complicated pattern that I wouldn't recommend:

idata = az.from_numpyro(mcmc, model, *model_args, **model_kwargs)

I suspect there's a way to get the plate dims without the model args and kwargs, but I haven't figured it out yet. That would simplify the example above to a much more reasonable form:

idata = az.from_numpyro(mcmc, model)

However, there can also be dependent draws across dimensions, and these aren't represented with plates. For instance, a categorical effect modelled with a ZeroSumNormal that has a dim for each of the 50 US states isn't represented with a plate in NumPyro, because the zero-sum constraint makes the entries along that dim dependent.

def model(X, y=None):
    ...
    b_state = numpyro.sample("b_state", dist.ZeroSumNormal(scale=1, event_shape=(50,)))
    ...

This wouldn't be captured by the approach above. I'm not sure whether there is a solution for that.
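One possible workaround, offered as an assumption rather than a known solution: let the user name event dimensions explicitly and append them after the plate-derived batch dimensions, since NumPyro sample shapes are laid out as batch shape followed by event shape. `merge_dims`, the `state` and `groups` dimension names, and `b_group_state` are hypothetical:

```python
def merge_dims(plate_dims, event_dims):
    """Combine plate-derived batch dims with user-named event dims.

    Batch dimensions precede event dimensions in NumPyro sample shapes,
    so event names are appended after plate names.
    """
    names = plate_dims.keys() | event_dims.keys()
    return {
        name: list(plate_dims.get(name, [])) + list(event_dims.get(name, []))
        for name in sorted(names)
    }


# Hypothetical inputs: b_state has no plate, only a 50-long event dim;
# b_group_state sits in a plate and also has an event dim.
plate_dims = {"alpha": ["categories"], "b_group_state": ["groups"]}
event_dims = {"b_state": ["state"], "b_group_state": ["state"]}
print(merge_dims(plate_dims, event_dims))
# {'alpha': ['categories'], 'b_group_state': ['groups', 'state'], 'b_state': ['state']}
```

The result has the `{variable: [dim names]}` shape that the `dims=` argument of ArviZ converters expects; the event-dim names themselves still have to come from the user.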