pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Automatically rendering probabilistic graphical models #949

Closed kpj closed 4 months ago

kpj commented 3 years ago

I was wondering whether there's any way of automatically rendering defined models, similar to pymc3.model_graph.model_to_graphviz.

If there's currently no way of doing this, I'd be happy to submit a PR (if anyone else is also interested in this functionality). One could either go for the pygraphviz approach or use daft instead.

I'd be very happy if someone could point me towards a part of numpyro's codebase which would make extracting the underlying network the easiest. Rendering only a subset of numpyro's primitives would be an acceptable start for me. I'd probably go for numpyro.sample and numpyro.plate first.

fritzo commented 3 years ago

See also https://github.com/pyro-ppl/pyro/issues/1502

fehiepsi commented 3 years ago

This would be a very nice feature. Currently, you can obtain the sample <-> plate relationship using a trace:

import jax
import jax.numpy as jnp

import numpyro.distributions as dist
import numpyro
from numpyro import handlers

def model(data):
    x = numpyro.sample("x", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(0, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(x, sd), obs=data)

data = jnp.ones(10)
trace = handlers.trace(handlers.seed(model, 0)).get_trace(data)
trace

which gives us

...
             ('obs',
              {'type': 'sample',
               'name': 'obs',
               'fn': <numpyro.distributions.distribution.ExpandedDistribution at 0x7f53a4412970>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),
               'scale': None,
               'is_observed': True,
               'intermediates': [],
               'cond_indep_stack': [CondIndepStackFrame(name='N', dim=-1, size=10)],
               'infer': {}})])

As you can see, the obs site has the field cond_indep_stack, which is a list of plates that obs belongs to.

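That is a good start, I guess. I'll think a bit more about how to extract the sample <-> sample relationship. One way is to use gradient information: if the log-probability of one site has a nonzero gradient with respect to the value of another site, the first site depends on the second.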

def log_probs(sample):
    with handlers.trace() as tr, handlers.seed(model, 0), handlers.substitute(data=sample):
        model(data)
    return {name: site["fn"].log_prob(site["value"])
            for name, site in tr.items() if site["type"] == "sample" and name not in sample}

jax.jacobian(log_probs)({"x": trace["x"]["value"]})

which returns

{'obs': {'x': DeviceArray([27.513908, 27.513908, 27.513908, 27.513908, 27.513908,
               27.513908, 27.513908, 27.513908, 27.513908, 27.513908],            dtype=float32)},
 'sd': {'x': DeviceArray(0., dtype=float32)}}

So we know that obs depends on x but sd does not.
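The same dependency check can be sketched without JAX or NumPyro, which may help show what the Jacobian is doing. This is a minimal sketch that swaps in central finite differences for jax.jacobian and hand-writes the per-site log-densities of the model above (x ~ Normal(0, 1), sd ~ LogNormal(0, 1), obs ~ Normal(x, sd)). The helper names are hypothetical, and nothing here is NumPyro API.

```python
import math


def normal_logpdf(v, loc, scale):
    # log-density of Normal(loc, scale) at v
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale * math.sqrt(2 * math.pi))


def lognormal_logpdf(v, loc, scale):
    # log-density of LogNormal(loc, scale) at v > 0 (change of variables from log v)
    return normal_logpdf(math.log(v), loc, scale) - math.log(v)


def site_log_probs(values):
    # hand-written per-site log-probs for: x ~ N(0, 1), sd ~ LogNormal(0, 1), obs ~ N(x, sd)
    x, sd, obs = values["x"], values["sd"], values["obs"]
    return {
        "x": normal_logpdf(x, 0.0, 1.0),
        "sd": lognormal_logpdf(sd, 0.0, 1.0),
        "obs": normal_logpdf(obs, x, sd),
    }


def dependencies(log_probs_fn, values, latents, eps=1e-6):
    """Map each site to the latent sites its log-prob is sensitive to."""
    deps = {name: set() for name in log_probs_fn(values)}
    for var in latents:
        # nudge one latent value up and down, keeping all other values fixed
        up = log_probs_fn(dict(values, **{var: values[var] + eps}))
        down = log_probs_fn(dict(values, **{var: values[var] - eps}))
        for site in deps:
            # a site that never reads `var` returns identical values, so its slope is exactly 0
            slope = (up[site] - down[site]) / (2 * eps)
            if site != var and slope != 0:
                deps[site].add(var)
    return deps


deps = dependencies(site_log_probs, {"x": 0.5, "sd": 1.2, "obs": 1.0}, latents=["x", "sd"])
print(deps)
```

Like the Jacobian version, this only makes sense for continuous latents: a discrete value cannot be nudged by eps.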

fritzo commented 3 years ago

Nice trick! I guess the grad trick will work for reparametrized continuous distributions, but not discrete distributions.

kpj commented 3 years ago

Thank you both so much, your suggestions were extremely helpful!

Using your code as a reference, I put together a little proof-of-concept which generates the following figure for the model you specified above: [figure] (one of many caveats: nested plates are not supported).

Should I clean up the code a bit and create a PR, so we can find some bugs? :-) If yes, where should such plotting functions go?

fehiepsi commented 3 years ago

> I guess the grad trick will work for reparametrized continuous distributions, but not discrete distributions.

You are right. We will be able to find the continuous -> discrete relation but not the reverse. I guess we can mimic the pattern: generate two values for that discrete variable (with the remaining variables fixed) and see which sites' log_prob changes.
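The perturbation idea for discrete sites can be sketched in plain Python. This is a hypothetical helper, not NumPyro code. It assumes a function that returns per-site log-probabilities for a dict of site values, swaps a discrete site through a few alternative values while holding everything else fixed, and reports which other sites' log-probabilities change. The toy log-probs hand-code a Categorical/Normal mixture (c ~ Categorical(probs), x ~ Normal(locs[c], 0.5)).

```python
import math

PROBS = [0.15, 0.3, 0.3, 0.25]
LOCS = [-2.0, 0.0, 2.0, 4.0]


def normal_logpdf(v, loc, scale):
    # log-density of Normal(loc, scale) at v
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale * math.sqrt(2 * math.pi))


def toy_log_probs(values):
    # hand-written per-site log-probs for: c ~ Categorical(PROBS), x ~ Normal(LOCS[c], 0.5)
    c, x = values["c"], values["x"]
    return {"c": math.log(PROBS[c]), "x": normal_logpdf(x, LOCS[c], 0.5)}


def discrete_dependents(log_probs_fn, values, site, alternatives):
    """Sites (other than `site`) whose log-prob changes when only `site` takes another value."""
    base = log_probs_fn(values)
    changed = set()
    for alt in alternatives:
        if alt == values[site]:
            continue
        new = log_probs_fn(dict(values, **{site: alt}))  # all other sites held fixed
        changed |= {name for name in base if name != site and new[name] != base[name]}
    return changed


print(discrete_dependents(toy_log_probs, {"c": 1, "x": 0.3}, "c", alternatives=range(4)))
# → {'x'}
```

Trying several alternative values guards against an unlucky swap that happens to leave a child's log-prob unchanged. The get_relations code below gets its alternatives by re-seeding the model num_tries times rather than from an explicit list.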

> which generates the following figure for the model

Awesome, @kpj! Looks beautiful to me.

> Should I clean up the code a bit and create a PR, so we can find some bugs?

I think you will want to play a bit more with different modeling patterns before making a PR. For example, the code above does not give you the relation a -> b -> c (it gives you a -> b, a -> c, b -> c instead). I guess you can use the following code:

import jax
import jax.numpy as jnp

import numpyro.distributions as dist
import numpyro
from numpyro import handlers

def model(data):
    x = numpyro.sample("x", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(x, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(x, sd), obs=data)

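def get_relations(model, args=(), kwargs=None, num_tries=10):
    """
    :param tuple args: positional arguments passed to the model
    :param dict kwargs: keyword arguments passed to the model
    :param int num_tries: times to trace model to detect discrete -> continuous dependency
    """
    kwargs = {} if kwargs is None else kwargs  # avoid a shared mutable default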
    trace = handlers.trace(handlers.seed(model, 0)).get_trace(*args, **kwargs)
    obs_sites = [name for name, site in trace.items()
                 if site["type"] == "sample" and site["is_observed"]]
    plate_deps = {}
    for name, site in trace.items():
        if site["type"] == "sample":
            plate_deps[name] = [(frame.name, frame.dim) for frame in site["cond_indep_stack"]]
    plate_sites = [name for name, site in trace.items() if site["type"] == "plate"]
    plate_samples = {k: [name for name in plate_deps
                     if k in [p[0] for p in plate_deps[name]]] for k in plate_sites}

    def get_log_probs(sample, seed=0):
        with handlers.trace() as tr, handlers.seed(model, seed), handlers.substitute(data=sample):
            model(*args, **kwargs)
        return {name: site["fn"].log_prob(site["value"])
                for name, site in tr.items() if site["type"] == "sample"}

    samples = {name: site["value"] for name, site in trace.items()
               if site["type"] == "sample" and not site["is_observed"]
               and not site["fn"].is_discrete}
    log_prob_grads = jax.jacobian(get_log_probs)(samples)
    sample_deps = {}
    for name, grads in log_prob_grads.items():
        sample_deps[name] = {n for n in grads if n != name and (grads[n] != 0).any()}

    # find discrete -> continuous dependency
    samples = {name: site["value"] for name, site in trace.items() if site["type"] == "sample"}
    discrete_sites = [name for name, site in trace.items() if site["type"] == "sample"
                      and not site["is_observed"] and site["fn"].is_discrete]
    log_probs_prototype = get_log_probs(samples)
    for name in discrete_sites:
        samples_ = samples.copy()
        samples_.pop(name)
        for i in range(num_tries):
            log_probs = get_log_probs(samples_, seed=i + 1)
            for var in samples:
                if var == name:
                    continue
                if (log_probs[var] != log_probs_prototype[var]).any():
                    sample_deps[var] |= {name}
    sample_sample = {}
    for name in samples:
        sample_sample[name] = [var for var in samples if var in sample_deps[name]]
    return {"sample_sample": sample_sample, "sample_plate": plate_deps,
            "plate_sample": plate_samples, "observed": obs_sites}

data = jnp.ones(10)
get_relations(model, (data,))

which gives a more accurate set of relations:

{'sample_sample': {'obs': ['sd', 'x'], 'sd': ['x'], 'x': []},
 'sample_plate': {'x': [], 'sd': [], 'obs': [('N', -1)]},
 'observed': ['obs']}
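As a rough illustration of the Graphviz route mentioned at the start of the thread, the relations dict above can be turned into Graphviz DOT source with plain string formatting. This is a hypothetical sketch, not kpj's proof-of-concept: it assumes the dict layout printed above, draws an edge parent -> child for every entry in sample_sample, shades observed sites, and (like the first prototype) refuses nested plates.

```python
def relations_to_dot(relations):
    """Render a get_relations-style dict as Graphviz DOT source (single-level plates only)."""
    observed = set(relations["observed"])
    top_level, clusters = [], {}
    for name, plates in relations["sample_plate"].items():
        if len(plates) > 1:
            raise NotImplementedError(f"site {name!r} is in nested plates")
        if plates:
            plate_name = plates[0][0]  # entries are (plate name, dim) pairs
            clusters.setdefault(plate_name, []).append(name)
        else:
            top_level.append(name)

    def node(name, indent):
        attrs = "shape=circle"
        if name in observed:
            attrs += ", style=filled, fillcolor=gray"  # shade observed sites
        return f'{indent}"{name}" [{attrs}];'

    lines = ["digraph model {"]
    lines += [node(name, "  ") for name in top_level]
    for plate_name, names in clusters.items():
        # Graphviz only draws a box around subgraphs whose name starts with "cluster"
        lines.append(f'  subgraph "cluster_{plate_name}" {{')
        lines.append(f'    label="{plate_name}";')
        lines += [node(name, "    ") for name in names]
        lines.append("  }")
    # sample_sample maps each site to the sites it depends on, so draw parent -> child
    for child, parents in relations["sample_sample"].items():
        lines += [f'  "{parent}" -> "{child}";' for parent in parents]
    lines.append("}")
    return "\n".join(lines)


relations = {'sample_sample': {'obs': ['sd', 'x'], 'sd': ['x'], 'x': []},
             'sample_plate': {'x': [], 'sd': [], 'obs': [('N', -1)]},
             'observed': ['obs']}
print(relations_to_dot(relations))
```

The resulting text can be saved to a file and rendered with the Graphviz command line, e.g. dot -Tpng model.dot -o model.png.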

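I will think a bit more about how to detect the discrete -> continuous relationship. Because the gradient trick only looks at continuous latent sites, the dependency of x on c is missed in the following model: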

def model(probs, locs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))

probs = jnp.array([0.15, 0.3, 0.3, 0.25])
locs = jnp.array([-2, 0, 2, 4])
get_relations(model, (probs, locs))
{'sample_sample': {'c': [], 'x': []},
 'sample_plate': {'c': [], 'x': []},
 'observed': []}

In the meantime, could you apply your visualization code to some models in the examples folder (e.g. baseball or sparse_regression)?

> nested plates are not supported

What is the reason for this? Is it a limitation of pygraphviz/daft? Edit: it seems to me that daft supports nested plates :)

kpj commented 3 years ago

> I think you will want to play a bit more with different modeling patterns before making a PR. [...] In the meantime, could you apply your visualization code to some models in the examples folder (e.g. baseball or sparse_regression)?

Yup, sounds reasonable. Will do after updating the rendering code!

> What is the reason for this? Is it a limitation of pygraphviz/daft? Edit: it seems to me that daft supports nested plates :)

Oh no, at that point just laziness on my side :-)

I looked into it in more detail, and while it is somewhat awkward to do, it is certainly possible. To create the right nesting, we'd have to know 1) what the nesting structure of each plate is and 2) which level is the deepest for each RV. Am I right in assuming that 2) can be solved by the frame.dim variable in your code (most negative dimension == deepest nesting level) and matching it with the respective name variable? For 1), is there any easy way of finding the "parent" plate of a given plate? Finally, can there be plates which are not subsets of each other?

fehiepsi commented 3 years ago

> we'd have to know 1) what the nesting structure of each plate is and 2) which level is the deepest for each RV

It seems to me that this info can be inferred from the sample_plate field in the output of the above code. frame.dim seems to be unnecessary (but might be useful if you want to display this info in the figure - I don't know). It is good practice to keep the most negative dimension as the deepest nesting level, but that is not necessarily the case. For example, consider the model

def model():
    with numpyro.plate("N", 10, dim=-2):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("M", 5, dim=-1):
            y = numpyro.sample("y", dist.Normal(0, 1))
    # with numpyro.plate("M", 5, dim=-1):
    #     z = numpyro.sample("z", dist.Normal(0, 1))
get_relations(model)

gives (I just updated the code to generate plate_sample field)

{'sample_sample': {'x': [], 'y': []},
 'sample_plate': {'x': [('N', -2)], 'y': [('M', -1), ('N', -2)]},
 'plate_sample': {'N': ['x', 'y'], 'M': ['y']},
 'observed': []}

Because plate N contains x and y while plate M only contains y, in this case, M is the deepest nesting level.

To find the "parent" plate of a given plate, it might be easier to use the plate_sample field. If the list of sample sites at plate M is a subset of the sample sites at plate N, then N is a parent of M.
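The subset heuristic can be written down directly. This is a minimal sketch assuming the plate_sample layout above (plate name -> list of site names). plate_parents is a hypothetical helper. A plate's parent is taken to be the smallest plate whose site set strictly contains its own, and plates with no such superset are treated as top level.

```python
def plate_parents(plate_sample):
    """Map each plate to its immediate parent plate (or None) using the subset heuristic."""
    sites = {plate: set(names) for plate, names in plate_sample.items()}
    parents = {}
    for inner, inner_sites in sites.items():
        # every plate whose sites form a strict superset of this plate's sites encloses it
        enclosing = [outer for outer, outer_sites in sites.items()
                     if outer != inner and inner_sites < outer_sites]
        # the innermost enclosing plate (fewest sites) is the immediate parent
        parents[inner] = min(enclosing, key=lambda p: len(sites[p])) if enclosing else None
    return parents


print(plate_parents({"N": ["x", "y"], "M": ["y"]}))
# → {'N': None, 'M': 'N'}
print(plate_parents({"A": ["a", "b", "c"], "B": ["b", "c"], "C": ["c"]}))
# → {'A': None, 'B': 'A', 'C': 'B'}
```

For plates that share sites without one containing the other, such as N and M in the model that follows, the heuristic simply reports both as top level.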

Can your code currently handle the following structure?

def model():
    with numpyro.plate("N", 10, dim=-2):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("M", 5, dim=-1):
            y = numpyro.sample("y", dist.Normal(0, 1))
    with numpyro.plate("M", 5, dim=-1):
        z = numpyro.sample("z", dist.Normal(0, 1))

get_relations(model)
{'sample_sample': {'x': [], 'y': [], 'z': []},
 'sample_plate': {'x': [('N', -2)],
  'y': [('M', -1), ('N', -2)],
  'z': [('M', -1)]},
 'plate_sample': {'N': ['x', 'y'], 'M': ['y', 'z']},
 'observed': []}

That is, the two plates share a common site y.

Update: Sadly, this does not seem to be possible with graphviz. Update 2: It is possible with daft, though I am not sure how to position nodes/plates in daft nicely.

kpj commented 3 years ago

> To find the "parent" plate of a given plate, it might be easier to use the plate_sample field. If the list of sample sites at plate M is a subset of the sample sites at plate N, then N is a parent of M.

Yeah, I was thinking of the same approach, but I wasn't sure if it would always work. For example, in the case you showed above, there's no such subset relation between the red and orange plates. However, for now I simply implemented this heuristic.

> Can your code currently handle the following structure? [...]
> That is, the two plates share a common site y.

One issue is that the two plates have the same name, M, so the plate_sample field only contains {'N': ['x', 'y'], 'M': ['y', 'z']} and one plate is lost. When I rename one of them to MM, I get the following figure: [figure]

> Update: Sadly, this does not seem to be possible with graphviz. Update 2: It is possible with daft, though I am not sure how to position nodes/plates in daft nicely.

One could certainly compute node positions using graphviz and then use daft for plotting. Unfortunately, this wouldn't work for arbitrary plate structures, as you noted above. For now, I'd be fine with simply raising an exception for such more complex structures. Do you have an intuition for how often they occur in practice?
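Raising an exception for such structures could look roughly like the following hypothetical check. It assumes the plate_sample layout from get_relations and flags any pair of plates that share sites without one containing the other. For the two-plate model above, where N contains x and y and M contains y and z, it raises.

```python
def check_plates_nested(plate_sample):
    """Raise if two plates share sites without one plate containing the other."""
    sites = {plate: set(names) for plate, names in plate_sample.items()}
    names = list(sites)
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            shared = sites[a] & sites[b]
            # disjoint or properly nested plates are fine; partial overlap is not
            if shared and not (sites[a] <= sites[b] or sites[b] <= sites[a]):
                raise NotImplementedError(
                    f"plates {a!r} and {b!r} partially overlap (shared sites: {sorted(shared)})")


check_plates_nested({"N": ["x", "y"], "M": ["y"]})  # nested: passes silently
try:
    check_plates_nested({"N": ["x", "y"], "M": ["y", "z"]})
except NotImplementedError as err:
    print(err)
```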


In addition, I looked at the baseball example. When I generate the relations for the model

def partially_pooled(at_bats, hits=None):
    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        phi_prior = dist.Beta(m * kappa, (1 - m) * kappa)
        phi = numpyro.sample("phi", phi_prior)
        return numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)

I get the following

{'sample_sample': {'kappa': ['obs'],
  'm': ['obs'],
  'obs': ['phi'],
  'phi': ['kappa', 'm', 'obs']},
 'sample_plate': {'m': [],
  'kappa': [],
  'phi': [('num_players', -1)],
  'obs': [('num_players', -1)]},
 'plate_sample': {'num_players': ['phi', 'obs']},
 'observed': []}

The plate information seems correct, but the RV links seem incorrect (e.g. kappa/m should not depend on obs).

fehiepsi commented 3 years ago

> The plate_sample field then only contains {'N': ['x', 'y'], 'M': ['y', 'z']} and one plate is lost. When I rename one of them to MM,

Good idea on renaming for visualization! I think it can be a good temporary workaround. Regarding the plates in that example, there are only 2 plates here: N and M, and y is their common variable.

> Do you have an intuition for how often they occur in practice?

You are right that it is not common. :+1: One example where it does occur is this model to predict sales across stores and products: we can have some variables for stores and some other variables for products. Raising an exception looks reasonable to me.

> the RV links seem incorrect (e.g. kappa/m should not depend on obs).

Thanks! I just fixed the bug there. I hadn't filtered out discrete sites correctly in my code.

> When I rename one of them to MM, I get the following figure:

The figure is beautiful! I guess your visualization code is pretty robust now. Could you make a PR for this? I'll also ping other devs to the discussion. Your work looks awesome!

I'm working on discrete -> continuous relationship. If it works, we will have beautiful graphics for those annotation models and we can compare against the figures in the corresponding paper. :)

kpj commented 3 years ago

> Good idea on renaming for visualization! I think it can be a good temporary workaround. Regarding the plates in that example, there are only 2 plates here: N and M, and y is their common variable.

I might be misunderstanding, but the variables z and y are not in the same plate (yet both their plates are called M), are they?

> Thanks! I just fixed the bug there. I hadn't filtered out discrete sites correctly in my code.

Wonderful! The corresponding plot looks quite nice now as well:

> The figure is beautiful! I guess your visualization code is pretty robust now. Could you make a PR for this? I'll also ping other devs to the discussion. Your work looks awesome!

Thanks so much! I still have my doubts about its robustness, but I'd be happy to create a PR so we can strive for it :-D Which submodule should I put the code in? I will also create a notebook with some nice exemplary renders.

> I'm working on discrete -> continuous relationship. If it works, we will have beautiful graphics for those annotation models and we can compare against the figures in the corresponding paper. :)

Awesome, that's super cool! Looking forward to many beautiful figures :-) I also tried some more examples with, e.g., BetaBinomial distributions and things went very wrong. I assume this is also related to the discrete -> continuous relationships you mention.

fehiepsi commented 3 years ago

> the variables z and y are not in the same plate (yet both their plates are called M), are they?

Sorry, I should have written the model in the following form to make it clearer (for a more complicated pattern, see this example, where plate statements are reused in a for loop):

def model():
    N_plate = numpyro.plate("N", 10, dim=-2)
    M_plate = numpyro.plate("M", 5, dim=-1)
    with N_plate:
        x = numpyro.sample("x", dist.Normal(0, 1))
        with M_plate:
            y = numpyro.sample("y", dist.Normal(0, 1))
    with M_plate:
        z = numpyro.sample("z", dist.Normal(0, 1))

> Which submodule should I put the code in?

I am not sure... I think numpyro.util is a good place for it but we might also expose it in numpyro.__init__ for convenience (personally, I believe that I'll often use this function)

kpj commented 3 years ago

> Sorry, I should have written the model in the following form to make it clearer (for a more complicated pattern, see this example, where plate statements are reused in a for loop) [...]

Right, I see. So the correct solution would be to draw two plates which are partially overlapping? Something like

+----+
|    |
|x +-+--+
|  |y|  |
+--+-+  |
   |   z|
   +----+

> I am not sure... I think numpyro.util is a good place for it but we might also expose it in numpyro.__init__ for convenience (personally, I believe that I'll often use this function)

Sounds good!

While setting up the PR, I noticed that the statement from numpyro import handlers makes importing numpyro crash due to a circular import:

ImportError: cannot import name 'COERCIONS' from partially initialized module 'numpyro.distributions.distribution' (most likely due to a circular import) (/numpyro/numpyro/distributions/distribution.py)

What would be the best way of handling this?

fehiepsi commented 3 years ago

> draw two plates which are partially overlapping?

That's right!

> What would be the best way of handling this?

I'll need to take a look, but as a temporary workaround, you can put extra import statements inside the function

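def f():
    # importing inside the function defers the import until call time,
    # which sidesteps the circular import while numpyro is being loaded
    from numpyro import handlers

    ...

fehiepsi commented 3 years ago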

@kpj I just updated the get_relations implementation above to account for the discrete -> continuous relationship. Could you try it with your beta-binomial model? Testing on the multinomial and dawid_skene models in the annotation example gives me the expected results:

# get_relations(multinomial, (annotations,))
{'sample_sample': {'zeta': [], 'pi': [], 'c': ['pi'], 'y': ['zeta', 'c']},
 'sample_plate': {'zeta': [('class', -1)],
  'pi': [],
  'c': [('item', -2)],
  'y': [('position', -1), ('item', -2)]},
 'plate_sample': {'class': ['zeta'], 'item': ['c', 'y'], 'position': ['y']},
 'observed': ['y']}
# get_relations(dawid_skene, (annotators, annotations,))
{'sample_sample': {'beta': [], 'pi': [], 'c': ['pi'], 'y': ['beta', 'c']},
 'sample_plate': {'beta': [('class', -1), ('annotator', -2)],
  'pi': [],
  'c': [('item', -2)],
  'y': [('position', -1), ('item', -2)]},
 'plate_sample': {'annotator': ['beta'],
  'class': ['beta'],
  'item': ['c', 'y'],
  'position': ['y']},
 'observed': ['y']}

kpj commented 3 years ago

I have now created a PR (https://github.com/pyro-ppl/numpyro/pull/952).

As an additional feature, it would be cool to (optionally) list the probability distribution of each RV on the side of the figure. Would it be possible to obtain this information in the get_relations function?


> I'll need to take a look, but as a temporary workaround, you can put extra import statements inside the function

Thanks, that seems to work for now.

> Could you try it with your beta-binomial model?

It seems to work nicely with my examples now!

kpj commented 3 years ago

At the moment, the implementation does not support rendering models which make use of the scan function. Here, we continue the discussion from #952.


@fehiepsi

Inspecting the model and extracting time information might not be too complicated, but rendering it will be trickier. It is not clear to me how to represent the loops (which can have a long historical dependency: e.g. x_t depends on x_{t-1} and x_{t-2} in ARIMA) in graphics. Do you know which plate graphical notation is used to represent loops? Probably we can start with history=1 and mimic some scan formats in the hmm enum example:

  • put all variables inside scan in a dashed-border rectangle (with label t=0)
  • create the same block with label t=1 and put it side by side with block t=0
  • add three dots beside the t=1 block
  • add edges from the t=0 block to the t=1 block
  • add edges from nodes outside the time blocks to nodes inside them

Do you think it is feasible to automatically render such a graph? I can try to extract time information from the model if you want. Rendering the time plate would be truly neat, but it is complicated I guess. :)
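The layout steps listed above can be sketched as DOT generation, assuming the per-step sites and the step-to-step (history=1) dependencies have already been extracted from the model. Everything here is hypothetical: unrolled_scan_dot and its inputs are for illustration only, and extracting this information from a NumPyro scan is not shown.

```python
def unrolled_scan_dot(step_sites, carry_edges, global_edges, num_steps=2):
    """Sketch of DOT output for a scan unrolled into dashed t=0, t=1, ... blocks.

    step_sites:   names of sites created inside the scanned function
    carry_edges:  (src, dst) pairs: dst at step t depends on src at step t-1 (history=1)
    global_edges: (src, dst) pairs: a site outside the loop feeds dst at every step
    """
    lines = ["digraph scan {", "  rankdir=LR;"]  # lay the time blocks out side by side
    for t in range(num_steps):
        # one dashed cluster per time step, labelled t=0, t=1, ...
        lines.append(f'  subgraph "cluster_t{t}" {{')
        lines.append(f'    label="t={t}"; style=dashed;')
        lines += [f'    "{site}_{t}" [label="{site}"];' for site in step_sites]
        lines.append("  }")
    # edges between consecutive time blocks
    for t in range(1, num_steps):
        lines += [f'  "{src}_{t - 1}" -> "{dst}_{t}";' for src, dst in carry_edges]
    # edges from nodes outside the time blocks into every block
    for src, dst in global_edges:
        lines += [f'  "{src}" -> "{dst}_{t}";' for t in range(num_steps)]
    # three dots after the last block to signal that the chain continues
    lines.append('  "dots" [label="...", shape=plaintext];')
    last = num_steps - 1
    for src in sorted({src for src, _ in carry_edges}):
        lines.append(f'  "{src}_{last}" -> "dots";')
    lines.append("}")
    return "\n".join(lines)


print(unrolled_scan_dot(["x", "y"], carry_edges=[("x", "x")],
                        global_edges=[("probs_x", "x"), ("probs_y", "y")]))
```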

@kpj In the general case, I'd just render the whole chain without the ... notation. One could then add a user-defined depth which specifies how many iterations to render before using three dots.

Or is it clear that the "sub-models" created by each scan iteration will always have the same structure? Could a numpyro model potentially do something like this:

def model():
    def foo(carry, i):
        # doesn't work, I know :-)
        if i == 6:
            numpyro.sample('special', dist.Beta(1, 1))
        else:
            numpyro.sample('default', dist.Beta(0.1, 0.1))

        return carry, i

    scan(foo, 0.01, jnp.arange(10))

Besides rendering dots, I also don't know of any better graphical representation.

Ignoring special spatial arrangements, we could probably treat scan as a "special plate", surround it with a differently styled box and put an arrow in-between adjacent boxes. For this we could just introduce a new category and store their ordering (nested scans are not supported, right?).

fehiepsi commented 4 months ago

Closed because the main features are addressed. We can create separate issues for other features like scan.