Open ColdTeapot273K opened 2 years ago
That would be nice. Do you know how the dimensions are stored in the numpyro model or mcmc object?
If it is possible to get a dictionary with variable names as keys and a list of the dimensions/plates as values, it would be straightforward to implement. The PyMC converter does that already.
I'll note that providing custom coords like so az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}) produces no effect (coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).
This sounds like a confusion between dimensions and coordinates. It uses cmdstanpy, but maybe this blog post I wrote can help clarify the difference.
Leaving some notes here. I might be interested in picking this up later if I have time, but I recently implemented something similar in my own code and figured I'd share it.
In NumPyro, dims are stored in the model (as opposed to the MCMC object) and tend to be captured by plates, which typically represent independent draws across a dimension:
```python
import numpyro
import numpyro.distributions as dist

def model(X, y=None):
    ...
    with numpyro.plate("categories", n_categories):
        alpha = numpyro.sample("alpha", dist.Normal(0, 5))
    ...
```
They can be pulled out of the model as follows:
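```python
from numpyro.infer.inspect import get_model_relations

# model_args / model_kwargs are the same inputs you pass to mcmc.run()
plates = get_model_relations(
    model,
    model_args=model_args,
    model_kwargs=model_kwargs,
)["plate_sample"]

# I haven't really checked this. A site inside several plates keeps all of
# them, but their order may not match the array's axis order; it's a start.
dims = {}
for plate_name, site_names in plates.items():
    for site_name in site_names:
        dims.setdefault(site_name, []).append(plate_name)
```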
Notice that it also requires `model_args` and `model_kwargs`: the model's inputs have to be supplied for the inspect tool to work. Supporting this would require passing the model to `from_numpyro` as an optional input AND passing the model inputs, which makes for a complicated pattern that I wouldn't recommend:
```python
idata = az.from_numpyro(mcmc, model, *model_args, **model_kwargs)
```
I'm guessing there's a way to get the plate dims without the model args and kwargs, but I haven't figured it out yet. That would simplify the example above to a much more reasonable form:
```python
idata = az.from_numpyro(mcmc, model)
```
However, there can also be dependent draws across dimensions, and these aren't represented with plates. For instance, a categorical variable modeled with a `ZeroSumNormal` that has one dim for each of the 50 US states isn't represented with a plate in NumPyro, since the zero-sum constraint makes the entries dependent.
```python
import numpyro
import numpyro.distributions as dist

def model(X, y=None):
    ...
    b_state = numpyro.sample("b_state", dist.ZeroSumNormal(scale=1, event_shape=(50,)))
    ...
```
This wouldn't be captured by the approach above, and I'm not sure there is a solution for that.
Tell us about it
When creating `InferenceData`, the resulting autogenerated coordinates are not very telling:
Now consider the NumPyro model below that produced these samples. These coords with autogenerated names are in fact plate dimensions, and plates have names.
Thoughts on implementation
It would be handy if `from_numpyro` could extract those sites from the NumPyro model.
I'll note that providing custom coords like so, `az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])})`, produces no effect (coords don't get renamed or anything). For now I stick to `.rename({"alpha_gender_dim_0": "gender"})`.