Closed kpj closed 4 months ago
This would be a very nice feature. Currently, you can obtain the sample <-> plate relationship using trace:
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import numpyro
from numpyro import handlers
def model(data):
    x = numpyro.sample("x", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(0, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(x, sd), obs=data)
data = jnp.ones(10)
trace = handlers.trace(handlers.seed(model, 0)).get_trace(data)
trace
which gives us
...
('obs',
{'type': 'sample',
'name': 'obs',
'fn': <numpyro.distributions.distribution.ExpandedDistribution at 0x7f53a4412970>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),
'scale': None,
'is_observed': True,
'intermediates': [],
'cond_indep_stack': [CondIndepStackFrame(name='N', dim=-1, size=10)],
'infer': {}})])
As you can see, the obs site has the field cond_indep_stack, which is a list of plates that obs belongs to.
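As a small illustration of reading that field, the sketch below collects the plate names for each sample site from a trace-like mapping. The `CondIndepStackFrame` namedtuple and the hand-written `trace` are stand-ins that mirror the printed output above; they are assumptions for illustration, not NumPyro objects, and `plates_per_site` is a hypothetical helper.

```python
from collections import namedtuple

# Stand-in for the frame type printed in the trace above (assumption: only
# the name/dim/size fields matter for this illustration).
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])

# Hand-written trace fragment shaped like the output above.
trace = {
    "x": {"type": "sample", "cond_indep_stack": []},
    "sd": {"type": "sample", "cond_indep_stack": []},
    "obs": {
        "type": "sample",
        "cond_indep_stack": [CondIndepStackFrame("N", -1, 10)],
    },
}


def plates_per_site(trace):
    """Map each sample site to the names of the plates enclosing it."""
    return {
        name: [frame.name for frame in site["cond_indep_stack"]]
        for name, site in trace.items()
        if site["type"] == "sample"
    }


site_plates = plates_per_site(trace)
print(site_plates)
```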
That is a good start I guess. I'll think a bit more about how to extract the sample <-> sample relationship. One way is to use grad information.
def log_probs(sample):
    with handlers.trace() as tr, handlers.seed(model, 0), handlers.substitute(data=sample):
        model(data)
    return {name: site["fn"].log_prob(site["value"])
            for name, site in tr.items() if site["type"] == "sample" and name not in sample}
jax.jacobian(log_probs)({"x": trace["x"]["value"]})
which returns
{'obs': {'x': DeviceArray([27.513908, 27.513908, 27.513908, 27.513908, 27.513908,
27.513908, 27.513908, 27.513908, 27.513908, 27.513908], dtype=float32)},
'sd': {'x': DeviceArray(0., dtype=float32)}}
So we know that obs depends on x but sd does not.
Nice trick! I guess the grad trick will work for reparametrized continuous distributions, but not discrete distributions.
Thank you both so much, your suggestions were extremely helpful!
Using your code as reference, I put together a little proof-of-concept which generates the following figure for the model you specified above. (one of many caveats: nested plates are not supported)
Should I clean up the code a bit and create a PR, so we can find some bugs? :-) If yes, where should such plotting functions go?
I guess the grad trick will work for reparametrized continuous distributions, but not discrete distributions.
You are right. We will be able to find the relation continuous -> discrete but not the reverse. I guess we can mimic the pattern: generate 2 values for that discrete variable (with the remaining variables fixed) and see which sites have log_prob changed.
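A minimal sketch of that perturbation idea, using a hand-written toy model in plain Python rather than NumPyro. The per-site log-probability function, the constants, and the helper name `dependents_of` are all assumptions made for illustration.

```python
import math


def normal_log_prob(value, loc, scale):
    """Log density of a normal distribution."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


# Toy model: c is discrete, x ~ Normal(LOCS[c], 0.5), y ~ Normal(0, 1).
LOCS = [-2.0, 0.0, 2.0, 4.0]
PROBS = [0.15, 0.3, 0.3, 0.25]


def log_probs(values):
    """Per-site log-probabilities for fixed values of all sites."""
    return {
        "c": math.log(PROBS[values["c"]]),
        "x": normal_log_prob(values["x"], LOCS[values["c"]], 0.5),
        "y": normal_log_prob(values["y"], 0.0, 1.0),
    }


def dependents_of(site, values, alternative):
    """Sites whose log-prob changes when `site` is set to `alternative`."""
    base = log_probs(values)
    perturbed = log_probs({**values, site: alternative})
    return sorted(n for n in base if n != site and base[n] != perturbed[n])


children = dependents_of("c", {"c": 0, "x": 1.0, "y": 0.3}, alternative=2)
print(children)
```

A single pair of values can miss a dependency (for example, if x happened to sit exactly halfway between the two candidate locations), so trying several values is the safer choice.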
which generates the following figure for the model
Awesome, @kpj! Looks beautiful to me.
Should I clean up the code a bit and create a PR, so we can find some bugs?
I think you will want to play a bit more with different modeling patterns before a PR. For example, the code above does not give you the relation a -> b -> c (it gave you a -> b, a -> c, b -> c instead). I guess you can use the following code
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import numpyro
from numpyro import handlers
def model(data):
    x = numpyro.sample("x", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(x, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(x, sd), obs=data)

def get_relations(model, args=(), kwargs={}, num_tries=10):
    """
    :param int num_tries: times to trace model to detect discrete -> continuous dependency
    """
    trace = handlers.trace(handlers.seed(model, 0)).get_trace(*args, **kwargs)
    obs_sites = [name for name, site in trace.items()
                 if site["type"] == "sample" and site["is_observed"]]

    plate_deps = {}
    for name, site in trace.items():
        if site["type"] == "sample":
            plate_deps[name] = [(frame.name, frame.dim) for frame in site["cond_indep_stack"]]

    plate_sites = [name for name, site in trace.items() if site["type"] == "plate"]
    plate_samples = {k: [name for name in plate_deps
                         if k in [p[0] for p in plate_deps[name]]] for k in plate_sites}

    def get_log_probs(sample, seed=0):
        with handlers.trace() as tr, handlers.seed(model, seed), handlers.substitute(data=sample):
            model(*args, **kwargs)
        return {name: site["fn"].log_prob(site["value"])
                for name, site in tr.items() if site["type"] == "sample"}

    samples = {name: site["value"] for name, site in trace.items()
               if site["type"] == "sample" and not site["is_observed"]
               and not site["fn"].is_discrete}
    log_prob_grads = jax.jacobian(get_log_probs)(samples)
    sample_deps = {}
    for name, grads in log_prob_grads.items():
        sample_deps[name] = {n for n in grads if n != name and (grads[n] != 0).any()}

    # find discrete -> continuous dependency
    samples = {name: site["value"] for name, site in trace.items() if site["type"] == "sample"}
    discrete_sites = [name for name, site in trace.items() if site["type"] == "sample"
                      and not site["is_observed"] and site["fn"].is_discrete]
    log_probs_prototype = get_log_probs(samples)
    for name in discrete_sites:
        samples_ = samples.copy()
        samples_.pop(name)
        for i in range(num_tries):
            log_probs = get_log_probs(samples_, seed=i + 1)
            for var in samples:
                if var == name:
                    continue
                if (log_probs[var] != log_probs_prototype[var]).any():
                    sample_deps[var] |= {name}

    sample_sample = {}
    for name in samples:
        sample_sample[name] = [var for var in samples if var in sample_deps[name]]
    return {"sample_sample": sample_sample, "sample_plate": plate_deps,
            "plate_sample": plate_samples, "observed": obs_sites}
data = jnp.ones(10)
get_relations(model, (data,))
which gives a better relationship
{'sample_sample': {'obs': ['sd', 'x'], 'sd': ['x'], 'x': []},
'sample_plate': {'x': [], 'sd': [], 'obs': [('N', -1)]},
'observed': ['obs']}
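To show how a dictionary of this shape could feed a renderer, here is a rough sketch that turns it into Graphviz DOT text, drawing each plate as a cluster and shading observed sites. `relations_to_dot` is a hypothetical helper written for this illustration, not the proof-of-concept code discussed in this thread, and it only handles sites that sit in at most one plate.

```python
def relations_to_dot(relations):
    """Render a relations dict as Graphviz DOT text (single-level plates only)."""
    observed = set(relations["observed"])

    # Group sites by the plate that directly encloses them (None = no plate).
    by_plate = {}
    for site, plates in relations["sample_plate"].items():
        if len(plates) > 1:
            raise NotImplementedError("nested plates are not handled in this sketch")
        plate = plates[0][0] if plates else None
        by_plate.setdefault(plate, []).append(site)

    def node(site):
        style = " [style=filled, fillcolor=gray]" if site in observed else ""
        return f'"{site}"{style};'

    lines = ["digraph model {"]
    for plate, sites in by_plate.items():
        if plate is None:
            lines += [f"  {node(s)}" for s in sites]
        else:
            # Graphviz draws a box around subgraphs whose name starts with "cluster".
            lines.append(f'  subgraph "cluster_{plate}" {{')
            lines.append(f'    label="{plate}";')
            lines += [f"    {node(s)}" for s in sites]
            lines.append("  }")
    # sample_sample maps each site to the sites it depends on (its parents).
    for child, parents in relations["sample_sample"].items():
        lines += [f'  "{parent}" -> "{child}";' for parent in parents]
    lines.append("}")
    return "\n".join(lines)


relations = {
    "sample_sample": {"obs": ["sd", "x"], "sd": ["x"], "x": []},
    "sample_plate": {"x": [], "sd": [], "obs": [("N", -1)]},
    "observed": ["obs"],
}
dot_source = relations_to_dot(relations)
print(dot_source)
```

The resulting text can be pasted into any DOT viewer; nested or overlapping plates would need extra handling.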
I will think a bit more about how to simply detect the discrete -> continuous relationship.
def model(probs, locs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))

probs = jnp.array([0.15, 0.3, 0.3, 0.25])
locs = jnp.array([-2, 0, 2, 4])
# positional model arguments are passed as a single tuple via `args`
get_relations(model, (probs, locs))
{'sample_sample': {'c': [], 'x': []},
'sample_plate': {'c': [], 'x': []},
'observed': []}
In the meantime, could you apply your visualization code to some models in the examples folder (e.g. baseball or sparse_regression)?
nested plates are not supported
What is the reason for this? Is it a limitation of pygraphviz/daft? edit: it seems to me that daft supports nested plates :)
I think you will want to play a bit more with different modeling patterns, before a PR. [...] In the meantime, could you apply your visualization code with some models in examples folder (e.g. baseball or sparse_regression)?
Yup, sounds reasonable. Will do after updating the rendering code!
What is the reason for this? Is it a limitation of pygraphviz/daft? edit: it seems to me that daft supports nested plates :)
Oh no, at that point just laziness on my side :-)
I looked into it in more detail, and while it is somewhat awkward to do, it is certainly possible. To create the right nesting, we'd have to know 1) what the nesting structure of each plate is and 2) what level is the deepest for each RV. Am I right in assuming that 2) can be solved by the frame.dim variable in your code (most negative dimension == deepest nesting level) and matching it with the respective name variable?
For 1), is there any easy way of finding the "parent" plate of a given plate?
Finally, can there be plates which are not subsets of each other?
we'd have to know 1) what the nesting structure of each plate is and 2) what level is the deepest for each RV
It seems to me that that info can be inferred from the sample_plate field in the output of the above code. frame.dim seems to be unnecessary (but might be useful if you want to display this info in the figure - I don't know). It is good practice to keep most negative dimension == deepest nesting level, but it is not necessarily true. For example, consider the model
def model():
    with numpyro.plate("N", 10, dim=-2):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("M", 5, dim=-1):
            y = numpyro.sample("y", dist.Normal(0, 1))
    # with numpyro.plate("M", 5, dim=-1):
    #     z = numpyro.sample("z", dist.Normal(0, 1))

get_relations(model)
gives (I just updated the code to generate the plate_sample field)
{'sample_sample': {'x': [], 'y': []},
'sample_plate': {'x': [('N', -2)], 'y': [('M', -1), ('N', -2)]},
'plate_sample': {'N': ['x', 'y'], 'M': ['y']},
'observed': []}
Because plate N contains x and y while plate M only contains y, in this case, M is the deepest nesting level.
To find the "parent" plate of a given plate, it might be easier to use the plate_sample field. If the list of sample sites at plate M is a subset of the sample sites at plate N, then N is a "father" of M.
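The subset heuristic can be sketched in a few lines of plain Python operating on a plate_sample dictionary. `plate_parents` is a hypothetical helper name, and the input is copied from the output shown above.

```python
def plate_parents(plate_sample):
    """For each plate, list the plates whose site set strictly contains its own."""
    sites = {plate: set(members) for plate, members in plate_sample.items()}
    return {
        plate: sorted(
            other
            for other in sites
            if other != plate and sites[plate] < sites[other]
        )
        for plate in sites
    }


plate_sample = {"N": ["x", "y"], "M": ["y"]}
parents = plate_parents(plate_sample)
print(parents)
```

Using a strict subset means two plates with identical site sets are reported as unrelated; that case is ambiguous under this heuristic and would need plate dims or the per-site plate lists to resolve.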
I guess currently, your code can handle the following structure?
def model():
    with numpyro.plate("N", 10, dim=-2):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("M", 5, dim=-1):
            y = numpyro.sample("y", dist.Normal(0, 1))
    with numpyro.plate("M", 5, dim=-1):
        z = numpyro.sample("z", dist.Normal(0, 1))

get_relations(model)
{'sample_sample': {'x': [], 'y': [], 'z': []},
'sample_plate': {'x': [('N', -2)],
'y': [('M', -1), ('N', -2)],
'z': [('M', -1)]},
'plate_sample': {'N': ['x', 'y'], 'M': ['y', 'z']},
'observed': []}
That is, two plates have a common site y.
updated: Sadly, this seems not to be an option of graphviz
updated2: But it is possible with daft though I am not sure how to set positions of nodes/plates in daft nicely
To find "parent" plate of a given plate, it might be easier to use plate_sample field. If the list of sample sites at plate M is a subset of the sample sites at plate N, then N is a "father" of M.
Yeah, I was thinking of the same approach but I wasn't sure if it would always work. For example in the case you showed above, there's no such subset relation between the red and orange plate. However, for now I simply implemented this heuristic.
I guess currently, your code can handle the following structure?
def model():
    with numpyro.plate("N", 10, dim=-2):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("M", 5, dim=-1):
            y = numpyro.sample("y", dist.Normal(0, 1))
    with numpyro.plate("M", 5, dim=-1):
        z = numpyro.sample("z", dist.Normal(0, 1))

get_relations(model)
{'sample_sample': {'x': [], 'y': [], 'z': []},
 'sample_plate': {'x': [('N', -2)], 'y': [('M', -1), ('N', -2)], 'z': [('M', -1)]},
 'plate_sample': {'N': ['x', 'y'], 'M': ['y', 'z']},
 'observed': []}
That is two plates have a common site y.
One issue is that two plates have the same name M. The plate_sample field then only contains {'N': ['x', 'y'], 'M': ['y', 'z']} and one plate is lost. When I rename one of them to MM, I get the following figure:
updated: Sadly, this seems not to be an option of graphviz
updated2: But it is possible with daft though I am not sure how to set positions of nodes/plates in daft nicely
One could certainly compute node positions using graphviz and then use daft for plotting. Unfortunately, this wouldn't work for arbitrary plate structures as you noted above. For now, I'd be fine to simply raise an exception for such more complex structures. Do you have an intuition for how often they occur in practice?
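A guard of that kind could look roughly like the following, assuming the plate_sample field is the only input. `check_plates_nested` is a hypothetical helper, and "partially overlapping" is taken to mean that two plates share a site without one plate's sites containing the other's.

```python
from itertools import combinations


def check_plates_nested(plate_sample):
    """Raise if two plates share sites without one containing the other."""
    sites = {plate: set(members) for plate, members in plate_sample.items()}
    for a, b in combinations(sites, 2):
        shared = sites[a] & sites[b]
        if shared and not (sites[a] <= sites[b] or sites[b] <= sites[a]):
            raise ValueError(
                f"plates {a!r} and {b!r} partially overlap on {sorted(shared)}"
            )


# Properly nested plates pass silently.
check_plates_nested({"N": ["x", "y"], "M": ["y"]})

# The structure with a shared site y and no containment is rejected.
try:
    check_plates_nested({"N": ["x", "y"], "M": ["y", "z"]})
    overlap_error = None
except ValueError as err:
    overlap_error = str(err)
print(overlap_error)
```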
In addition, I looked at the baseball example. When I generate the relations for the model
def partially_pooled(at_bats, hits=None):
    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        phi_prior = dist.Beta(m * kappa, (1 - m) * kappa)
        phi = numpyro.sample("phi", phi_prior)
        return numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
I get the following
{'sample_sample': {'kappa': ['obs'],
'm': ['obs'],
'obs': ['phi'],
'phi': ['kappa', 'm', 'obs']},
'sample_plate': {'m': [],
'kappa': [],
'phi': [('num_players', -1)],
'obs': [('num_players', -1)]},
'plate_sample': {'num_players': ['phi', 'obs']},
'observed': []}
The plate information seems correct, but the RV links seem incorrect (e.g. kappa/m should not depend on obs).
The plate_sample field then only contains {'N': ['x', 'y'], 'M': ['y', 'z']} and one plate is lost. When I rename one of them to MM,
Good idea on renaming for visualization! I think it can temporarily be a good workaround. Regarding the plates in that example, there are only 2 plates here: N and M, and y is their common variable.
Do you have an intuition for how often they occur in practice?
You are right that it is not common. :+1: (e.g. this model to predict sales across stores and products: we can have some variables for stores and some other variables for products). Raising an exception looks reasonable to me.
the RV links seem incorrect (e.g. kappa/m should not depend on obs).
Thanks! I just fixed the bug there. I haven't filtered out discrete sites correctly in my code.
When I rename one of them to MM, I get the following figure:
The figure is beautiful! I guess your visualization code is pretty robust now. Could you make a PR for this? I'll also ping other devs to the discussion. Your work looks awesome!
I'm working on the discrete -> continuous relationship. If it works, we will have beautiful graphics for those annotation models and we can compare against the figures in the corresponding paper. :)
Good idea on renaming for visualization! I think it can temporarily be a good workaround. Regarding the plates in that example, there are only 2 plates here: N and M, and y is their common variables.
I might be misunderstanding, but the variables z and y are not in the same plate (yet both their plates are called M), are they?
Thanks! I just fixed the bug there. I haven't filtered out discrete sites correctly in my code.
Wonderful! The corresponding plot looks quite nice now as well:
The figure is beautiful! I guess your visualization code is pretty robust now. Could you make a PR for this? I'll also ping other devs to the discussion. Your work looks awesome!
Thanks so much! I still have my doubts about its robustness, but I'd be happy to create a PR so we can strive for it :-D Which submodule should I put the code in? I will also create a notebook with some nice exemplary renders.
I'm working on discrete -> continuous relationship. If it works, we will have beautiful graphics for those annotation models and we can compare against the figures in the corresponding paper. :)
Awesome, that's super cool! Looking forward to many beautiful figures :-)
I also tried some more examples with, e.g., BetaBinomial distributions and things went very wrong. I assume this is also related to the discrete -> continuous relationships you mention.
the variables z and y are not in the same plate (yet both their plates are called M), are they?
Sorry, I should write the model in the following form to make it clearer (for a more complicated pattern, you can see this example, where plate statements are reused in a for loop)
def model():
    N_plate = numpyro.plate("N", 10, dim=-2)
    M_plate = numpyro.plate("M", 5, dim=-1)
    with N_plate:
        x = numpyro.sample("x", dist.Normal(0, 1))
        with M_plate:
            y = numpyro.sample("y", dist.Normal(0, 1))
    with M_plate:
        z = numpyro.sample("z", dist.Normal(0, 1))
Which submodule should I put the code in?
I am not sure... I think numpyro.util is a good place for it but we might also expose it in numpyro.__init__ for convenience (personally, I believe that I'll often use this function)
Sorry, I should write the model in the following form to make it clearer (for more complicated pattern, you can see this example, where plate statements are reused in a for loop)
def model():
    N_plate = numpyro.plate("N", 10, dim=-2)
    M_plate = numpyro.plate("M", 5, dim=-1)
    with N_plate:
        x = numpyro.sample("x", dist.Normal(0, 1))
        with M_plate:
            y = numpyro.sample("y", dist.Normal(0, 1))
    with M_plate:
        z = numpyro.sample("z", dist.Normal(0, 1))
Right, I see. So the correct solution would be to draw two plates which are partially overlapping? Something like
+----+
|    |
|x +-+--+
|  |y|  |
+--+-+  |
   |   z|
   +----+
I am not sure... I think numpyro.util is a good place for it but we might also expose it in numpyro.__init__ for convenience (personally, I believe that I'll often use this function)
Sounds good!
While setting up the PR I noticed that the import from numpyro import handlers makes importing numpyro crash due to circular imports:
ImportError: cannot import name 'COERCIONS' from partially initialized module 'numpyro.distributions.distribution' (most likely due to a circular import) (/numpyro/numpyro/distributions/distribution.py)
What would be the best way of handling this?
draw two plates which are partially overlapping?
That's right!
What would be the best way of handling this?
I'll need to take a look, but as a temporary workaround, you can put extra import statements inside the function
def f():
    import ...
    ...
@kpj I just updated the get_relations implementation above to account for the discrete -> continuous relationship. Could you try with your beta-binomial model? Testing on the multinomial and dawid_skene models in the annotation examples gives me the expected result
# get_relations(multinomial, (annotations,))
{'sample_sample': {'zeta': [], 'pi': [], 'c': ['pi'], 'y': ['zeta', 'c']},
'sample_plate': {'zeta': [('class', -1)],
'pi': [],
'c': [('item', -2)],
'y': [('position', -1), ('item', -2)]},
'plate_sample': {'class': ['zeta'], 'item': ['c', 'y'], 'position': ['y']},
'observed': ['y']}
# get_relations(dawid_skene, (annotators, annotations,))
{'sample_sample': {'beta': [], 'pi': [], 'c': ['pi'], 'y': ['beta', 'c']},
'sample_plate': {'beta': [('class', -1), ('annotator', -2)],
'pi': [],
'c': [('item', -2)],
'y': [('position', -1), ('item', -2)]},
'plate_sample': {'annotator': ['beta'],
'class': ['beta'],
'item': ['c', 'y'],
'position': ['y']},
'observed': ['y']}
I have now created a PR (https://github.com/pyro-ppl/numpyro/pull/952).
As an additional feature it would be cool to (optionally) list the probability distribution of each RV on the side of the figure. Would it be possible to obtain this information in the get_relations function?
I'll need to take a look, but as a temporary workaround, you can put extra import statements inside the function
Thanks, that seems to work for now.
Could you try with your beta-binomial model?
It seems to nicely work with my examples now!
At the moment, the implementation does not support rendering models which make use of the scan function.
Here, we continue the discussion from #952.
@fehiepsi
To inspect the model and extract time information, it might not be so complicated. But it will be trickier to render. It is not clear to me how to represent the loops (which can have a long historical dependency: e.g. x_t depends on x_{t-1} and x_{t-2} in ARIMA) in graphics. Do you know which plate graphical notation is used to represent the loops? Probably we can start with history=1 and mimic some scan formats in the hmm enum example:
- put all variables inside scan in a dashed-border rectangle (with label t=0)
- create the same block with label t=1 and put it side-to-side with block t=0
- add three dots beside the t=1 block
- add edges from the t=0 block to the t=1 block
- add edges from outside-of-time-block nodes to inside-of-time-block nodes
Do you think it is feasible to automatically render such a graph? I can try to extract time information from the model if you want. Rendering the time plate would be truly neat, but it is complicated I guess. :)
@kpj
In the general case, I'd just render the whole chain without the ... notation.
One could then add a user-defined depth which specifies how many iterations to render before using three dots.
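As a rough sketch of that depth parameter, the function below decides which time blocks to draw and where the three dots go. It is a plain-Python illustration only: `unrolled_steps` and its arguments are hypothetical names, and it assumes every scan iteration has the same structure.

```python
def unrolled_steps(num_steps, depth):
    """Labels of the time blocks to draw, with '...' marking omitted steps."""
    if depth < 1:
        raise ValueError("depth must be at least 1")
    shown = [f"t={t}" for t in range(min(depth, num_steps))]
    if num_steps > depth:
        shown.append("...")
    return shown


blocks = unrolled_steps(num_steps=10, depth=2)
# Arrows between adjacent blocks, in drawing order.
edges = list(zip(blocks, blocks[1:]))
print(blocks, edges)
```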
Or is it clear that the "sub-models" created by each scan iteration will always have the same structure? Could a numpyro model potentially do something like this:
def model():
    def foo(carry, i):
        # doesn't work, I know :-)
        if i == 6:
            numpyro.sample('special', dist.Beta(1, 1))
        else:
            numpyro.sample('default', dist.Beta(0.1, 0.1))
        return carry, i

    scan(foo, 0.01, jnp.arange(10))
Besides rendering dots, I also don't know of any better graphical representation.
Ignoring special spatial arrangements, we could probably treat scan as a "special plate", surround it with a differently styled box and put an arrow in-between adjacent boxes. For this we could just introduce a new category and store their ordering (nested scans are not supported, right?).
Closed because the main features are addressed. We can create separate issues for other features like scan.
I was wondering whether there's any way of automatically rendering defined models, similar to pymc3.model_graph.model_to_graphviz. If there's currently no way of doing this, I'd be happy to submit a PR (if anyone else is also interested in this functionality). One could either go for the pygraphviz approach or use daft instead.
I'd be very happy if someone could point me towards a part of numpyro's codebase which would make extracting the underlying network the easiest. Rendering only a subset of numpyro's primitives would be an acceptable start for me. I'd probably go for numpyro.sample and numpyro.plate first.