pymc-devs / symbolic-pymc

Tools for the symbolic manipulation of PyMC models, Theano, and TensorFlow graphs.
https://pymc-devs.github.io/symbolic-pymc

Generalize Graph Traversal and Manipulation #6

Closed brandonwillard closed 5 years ago

brandonwillard commented 5 years ago

We need generalized graph traversal and manipulation functions—especially since we cannot use Theano FunctionGraphs and optimizers on TensorFlow graphs. Since the core graph operations will be in miniKanren, these functions might benefit from being in miniKanren, too.

What might this look like? For a reduction relation, a simple recursive miniKanren relation would work. Here's a complete example of such a relation:

import theano
import theano.tensor as tt
from theano.printing import debugprint as tt_dprint

from unification import var

from kanren import run, eq, conde, lall, fact, Relation
from kanren.core import success, fail

from symbolic_pymc.theano.meta import mt
from symbolic_pymc.theano.printing import tt_pprint

_cxx_config = theano.config.cxx
theano.config.cxx = ''

reduces = Relation('reduces')

x_lv = var('x_lv')
y_lv = var('y_lv')
z_lv = var('z_lv')

# x + x -> 2 * x
x_add_mt = mt.add(x_lv, x_lv)
fact(reduces,
     x_add_mt,
     mt.mul(tt.constant(2), x_lv))
# x * x -> x**2
pow_sum_mt = mt.mul(x_lv, x_lv)
fact(reduces,
     pow_sum_mt,
     mt.pow(x_lv, tt.constant(2)))
# -(-x) -> x
fact(reduces,
     mt.neg(mt.neg(x_lv)),
     x_lv)
# exp(log(x)) -> x
fact(reduces,
     mt.exp(mt.log(x_lv)),
     x_lv)
# log(exp(x)) -> x
fact(reduces,
     mt.log(mt.exp(x_lv)),
     x_lv)
# x**y * x**z -> x**(y + z)
pow_mul_mt = mt.mul(mt.pow(x_lv, y_lv),
                    mt.pow(x_lv, z_lv))
fact(reduces,
     pow_mul_mt,
     mt.pow(x_lv,
            mt.add(y_lv, z_lv)))

def reduceo(in_expr, out_expr):
    """Recursive relation between an expression and its reduced form."""
    expr_rdcd = var()
    return (conde,
            # Attempt to apply a single reduction...
            [(reduces, in_expr, expr_rdcd),
             # ...and, if it succeeds, consider another
             (reduceo, expr_rdcd, out_expr)],
            # Return the input unchanged
            [eq(out_expr, in_expr)])

def reduce_expr(input_expr, n=None):
    reduced_expression = var()
    res = run(n, reduced_expression,
              (reduceo, input_expr, reduced_expression))
    return res

Using the above, for a + a we get the following:

a = tt.vector('a')
print(tt_pprint([res.reify() for res in reduce_expr(a + a)]))
a in R**(N^a_0)
(2 * a)
(a + a)

miniKanren produces two results, one per conde branch: the reduced expression and the original. The nature of these results raises some questions:

  1. Since relation ordering shouldn't matter, how do we want to designate/present the "most reduced" (i.e. fixed point) result in the stream?
  2. In other situations, we might want to order/"weigh" output from the miniKanren stream (i.e. equivalent graph variations); how best do we do that?

There are "impure" versions of conde that perform cut-like operations—effectively stopping at certain conde branches—and those can be used to more efficiently produce fixed-point-like results. There's also this "guided" version of conde called condp that might allow for fewer relational compromises.
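To make the fixed-point question concrete, here's a small, library-free Python sketch of the same two-branch structure: each stream result is either a further-reduced term or the term itself. The tuple-term representation and every name below are illustrative stand-ins, not symbolic-pymc or kanren APIs.

```python
# Toy sketch (hypothetical helpers, not kanren APIs): terms are nested
# tuples, and a rule maps a term to a rewritten term or None.

def rewrite_once(term, rules):
    """Apply the first matching rule at the term's root, or return None."""
    for rule in rules:
        reduced = rule(term)
        if reduced is not None:
            return reduced
    return None

def all_reductions(term, rules):
    """Yield every successive reduction and then the term itself,
    mirroring the two conde branches: "reduce again" vs. "unchanged"."""
    reduced = rewrite_once(term, rules)
    if reduced is not None:
        yield from all_reductions(reduced, rules)
    yield term

def add_to_mul(term):
    """x + x -> 2 * x, as a rule on tuple terms."""
    if isinstance(term, tuple) and term[0] == 'add' and term[1] == term[2]:
        return ('mul', 2, term[1])
    return None

results = list(all_reductions(('add', 'a', 'a'), [add_to_mul]))
# Here the fixed point happens to be the first element of the stream,
# but nothing in the stream itself relationally marks it as such.
fixed_point = results[0]
```

The "designation" question is visible even in this toy version: the fixed point is only identifiable by its position in the stream, which is an artifact of clause ordering rather than a relational property.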

The same questions apply to general graph traversal, but also include considerations for how graphs (expression graphs, really) are traversed. In the example above, the graphs aren't really walked:

print(tt_pprint([res.reify() for res in reduce_expr(tt.log(tt.exp(a)) + tt.log(tt.exp(a)))]))
a in R**(N^a_0)
(2 * log(exp(a)))
(log(exp(a)) + log(exp(a)))

Notice how the first—and most reduced—expression has not applied the log(exp(a)) = a reduction. This is because the entire graph does not match said reduction; instead, a subgraph does.

We can, for instance, easily produce the desired result by transforming these objects into their S-expression-like forms and using conde to walk the arguments in cons-fashion (well, after this PR goes through). The condes in this process are also subject to the above questions (or at least similar ones).
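As a rough illustration of that argument-walking idea (using tuple-based S-expression terms; none of these names are symbolic-pymc or kanren APIs), a plain-Python, bottom-up walker applies reductions at every node instead of only at the root:

```python
# Plain-Python sketch: terms are ('op', arg, ...) tuples; the walker
# reduces the arguments first, then the rebuilt node (illustrative only).

def simplify(term):
    """Bottom-up reduction: simplify the arguments, then the node."""
    if not isinstance(term, tuple):
        return term
    op, *args = term
    term = (op, *[simplify(arg) for arg in args])
    # log(exp(x)) -> x, now applied at every node, not just the root
    if term[0] == 'log' and isinstance(term[1], tuple) and term[1][0] == 'exp':
        return term[1][1]
    # x + x -> 2 * x
    if term[0] == 'add' and term[1] == term[2]:
        return ('mul', 2, term[1])
    return term

expr = ('add', ('log', ('exp', 'a')), ('log', ('exp', 'a')))
result = simplify(expr)  # both sub-graph and root reductions fire
```

This recovers the "fully reduced" result that the root-only matching above misses, at the cost of hard-coding one traversal order.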

brandonwillard commented 5 years ago

For the sake of completeness, here's an extension of the previous example that uses a relation performing a single step of graph-based reduction: it reduces an expression's arguments (i.e. its sub-graph-producing inputs) and then reduces the resulting rebuilt expression.

from kanren.term import operator, arguments
from kanren.assoccomm import buildo

from symbolic_pymc.unify import etuple

def graph_reduceo(in_graph, out_graph):
    """Relation for an expression (sub-)graph reduction."""
    in_op = operator(in_graph)
    in_args = arguments(in_graph)
    res = (lall,)
    reduced_args = etuple()
    for a in in_args:
        a_reduced = var()
        res += ((reduceo, a, a_reduced),)
        reduced_args += (a_reduced,)

    arg_red_graph = var()
    arg_reduce = (buildo, in_op, reduced_args, arg_red_graph)
    res += (arg_reduce,)

    res += ((reduceo, arg_red_graph, out_graph),)

    return res

def reduce_expr(input_expr, n=None):
    reduced_expression = var()
    res = run(n, reduced_expression,
              (graph_reduceo, input_expr, reduced_expression))
    return res

Now, the example from before will reduce as expected:

print(tt_pprint([res.reify() for res in reduce_expr(tt.log(tt.exp(a)) + tt.log(tt.exp(a)))]))
a in R**(N^a_0)
(2 * a)
(log(exp(a)) + a)
(a + log(exp(a)))
(2 * log(exp(a)))
(a + a)
(log(exp(a)) + log(exp(a)))

A conde-based recurrence would likely be sufficient for complete graph reduction/normalization (assuming, of course, that the relations and input actually have a fixed point).

The basics of graph traversal and mutation in miniKanren—instead of standard Python graph traversal (e.g. over meta object graphs) with orchestrated miniKanren use—seem quite simple and succinct, especially with respect to "backtracking"-like capabilities. Likewise, the standard graph traversal choices (e.g. depth/breadth-first, graph-first/args-first, etc.) seem just as accessible, but we'll have to experiment with those ideas a bit more.

Actually, it seems considerably simpler to include a diversity of those graph-processing choices in the context of miniKanren: one could conde-branch over different processing orders entirely, and, with something like condp, such choices could be chosen/short-circuited more efficiently!
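To sketch what "branching over processing orders" might look like (in plain Python, with illustrative names; this is not how kanren's streams are implemented), one can interleave the result streams of two traversal orders, much as conde interleaves the streams of its branches:

```python
from itertools import zip_longest

def topdown(term):
    """Yield the node first, then its arguments' sub-terms."""
    yield term
    if isinstance(term, tuple):
        for arg in term[1:]:
            yield from topdown(arg)

def bottomup(term):
    """Yield the arguments' sub-terms first, then the node."""
    if isinstance(term, tuple):
        for arg in term[1:]:
            yield from bottomup(arg)
    yield term

def interleave(*streams):
    """Round-robin over several result streams, like conde's
    interleaving of its branches' streams."""
    sentinel = object()
    for group in zip_longest(*streams, fillvalue=sentinel):
        for item in group:
            if item is not sentinel:
                yield item

expr = ('add', 'a', ('neg', 'b'))
mixed = list(interleave(topdown(expr), bottomup(expr)))
```

A consumer of `mixed` sees results from both orders fairly, so whichever order reaches an interesting sub-term first surfaces it first.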

brandonwillard commented 5 years ago

Side thought: if we know that the graph and relations are normalizing, then we could probably process the conde-produced reduction branches using async generators in the underlying kanren stream processing (e.g. lany and/or lall in this implementation), then terminate when the first normal form is found! I'm not sure how the termination would work across asynchronous goal streams, though.

In other words, I'm talking about a "race to the normal form" between asynchronous streams. I imagine that there are better "meta" approaches to this exact situation (e.g. Knuth-Bendix-like assessments on the relations involved), but this seems like a more general approach—albeit fundamentally limited by async overhead.
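A minimal sketch of the "race" idea, assuming all we care about is which strategy normalizes first: run two simulated reduction strategies concurrently with asyncio and cancel the losers. This is purely illustrative and says nothing about how kanren's lany/lall would actually be made asynchronous.

```python
import asyncio

async def strategy(name, steps, step_time):
    """Simulate a reduction strategy that needs `steps` rewrites, each
    taking `step_time` seconds, to reach a normal form (hypothetical)."""
    for _ in range(steps):
        await asyncio.sleep(step_time)
    return name

async def race_to_normal_form():
    """Run both strategies concurrently; keep the first finisher and
    cancel the rest."""
    tasks = [asyncio.ensure_future(strategy('eager', 2, 0.001)),
             asyncio.ensure_future(strategy('lazy', 50, 0.001))]
    done, pending = await asyncio.wait(
        tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    return done.pop().result()

winner = asyncio.run(race_to_normal_form())
```

The tricky part flagged above—terminating across goal streams—shows up here as the explicit cancellation of the pending tasks.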

@webyrd might have some thoughts on this.

brandonwillard commented 5 years ago

FYI: The fundamentals of these concerns are also expressed in https://github.com/logpy/logpy/issues/3.

brandonwillard commented 5 years ago

I'm just realizing that the condp implementation in https://github.com/logpy/logpy/pull/71 might not directly help in the graph normalization case, since it—for instance—provides no means of assessing cross-cond-branch results. For example, if the first REDUCE branch succeeded—e.g. indicated by expr_rdcd being unified with a reduced form of the input expression—then we would want to use condp to avoid adding a goal to the stream for the fall-back BASE branch (i.e. the one that returns the unchanged expression).

It seems like the standard condu/conda, or a custom—and possibly non-relational—goal, might be the best way to get depth-first and/or short-circuited evaluation for this situation. Either that, or tweak the implementation in https://github.com/logpy/logpy/pull/71 so that it passes incremental/cumulative states to the branch suggestion functions.

Otherwise, one could improve a few edge cases using condp: e.g. if in_expr is an "atom", don't try to reduce it.
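To illustrate the cut-like behavior being described (plain Python generators standing in for goal streams; all names here are hypothetical): commit to the REDUCE branch whenever it produces anything, and only fall back to BASE when it doesn't, which is roughly conda's semantics.

```python
def conda_branches(term, reduce_branch, base_branch):
    """Committed choice: if the REDUCE branch yields anything, never run
    the BASE fall-back (a rough sketch of conda-style cut behavior)."""
    produced = False
    for result in reduce_branch(term):
        produced = True
        yield result
    if not produced:
        yield from base_branch(term)

def reduce_branch(term):
    """-(-x) -> x, on toy tuple terms."""
    if (isinstance(term, tuple) and term[0] == 'neg'
            and isinstance(term[1], tuple) and term[1][0] == 'neg'):
        yield term[1][1]

def base_branch(term):
    """Return the expression unchanged."""
    yield term

reduced = list(conda_branches(('neg', ('neg', 'x')), reduce_branch, base_branch))
unchanged = list(conda_branches(('neg', 'x'), reduce_branch, base_branch))
```

The cost of the cut is the usual one: the fall-back branch is unreachable whenever any reduction applies, so the enumeration is no longer fully relational.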

brandonwillard commented 5 years ago

While I'm at it, it's worth giving another short demonstration of why the miniKanren/relational approach to graph/model manipulation is worthwhile.

Continuing from the example code above, say we want to obtain a very specific reduction, like those with the form log(...) + a. Because we're working with relations, it's as simple as

test_expr = tt.log(tt.exp(a)) + tt.log(tt.exp(a))
expr_pattern = mt.add(mt.log(var()), a)
test_res = run(None, expr_pattern,
               (graph_reduceo, test_expr, expr_pattern))
print(tt_pprint([res.reify() for res in test_res]))
a in R**(N^a_0)
(log(exp(a)) + a)

Another example is "un-reducing" (i.e. running the relation in reverse, as in run(n, q, (graph_reduceo, q, test_expr))), and, along similar lines, producing all graphs between two expressions containing a mix of logic variables and base terms.
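The "un-reducing" direction can be sketched in plain Python by running toy rules backwards, enumerating expressions that reduce to a given target in one step. (With relations this direction requires no extra code; the tuple terms and names below are illustrative only.)

```python
def unreduce_once(term):
    """Run toy reduction rules backwards, yielding every expression that
    reduces to `term` in a single step (hypothetical helper)."""
    # x -> -(-x): the reverse of -(-x) -> x
    yield ('neg', ('neg', term))
    # x -> log(exp(x)): the reverse of log(exp(x)) -> x
    yield ('log', ('exp', term))
    # 2 * x -> x + x: the reverse of x + x -> 2 * x
    if isinstance(term, tuple) and term[0] == 'mul' and term[1] == 2:
        yield ('add', term[2], term[2])

expansions = list(unreduce_once(('mul', 2, 'a')))
```

Unlike the relational version, this imperative sketch needed a second, hand-written rule set for the reverse direction—exactly the duplication that miniKanren avoids.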