aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Merge `Scan`'s inner-graphs without graph traversal #777

Open brandonwillard opened 2 years ago

brandonwillard commented 2 years ago

The current form of Scan.__eq__ requires an expensive traversal of two Scan Ops' inner-graphs in order to determine equivalence (i.e. equal_computations).

The underlying problem is that the Scan Op carries its own Aesara graph, and, in order to determine whether or not two Ops are equal, we need to compare those graphs.

Normally, Aesara is able to use identity-based comparisons of graphs within the same FunctionGraph (where all this is most important) by "merging" sub-graphs when they're equivalent (see MergeOptimizer). This has the effect of replacing duplicates with a single unique instance, so that identity-based ("is") comparisons are all that's necessary.
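MergeOptimizer itself isn't reproduced here, but as a rough illustration of the merging idea, the following pure-Python sketch hash-conses a toy graph. The `Var`/`Apply` classes and the `merge` helper are hypothetical stand-ins, not Aesara's API. Each node is keyed by its op and its already-merged inputs (compared by identity), so equivalent sub-graphs collapse into one instance and `is` becomes a sufficient equality test.

```python
class Var:
    """Toy graph variable; `owner` is the Apply node that produced it (None for inputs)."""

    def __init__(self, name=None, owner=None):
        self.name = name
        self.owner = owner


class Apply:
    """Toy node: an op (here just a string) applied to input variables, with one output."""

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = tuple(inputs)
        self.output = Var(owner=self)


def merge(var, cache):
    """Return the canonical variable for `var`.

    `cache` maps (op, canonical inputs) to an existing output, so structurally
    identical sub-graphs are replaced by a single shared instance.
    """
    if var.owner is None:
        return var  # graph inputs are only ever equal to themselves
    canon_inputs = tuple(merge(i, cache) for i in var.owner.inputs)
    # Var does not override __eq__/__hash__, so the key compares inputs by identity.
    key = (var.owner.op, canon_inputs)
    if key not in cache:
        cache[key] = Apply(var.owner.op, canon_inputs).output
    return cache[key]


x = Var("x")
p = Apply("mul", [x, x]).output
q = Apply("mul", [x, x]).output  # a duplicate of p, built separately
s = Apply("add", [p, q]).output

cache = {}
m = merge(s, cache)
print(p is q)                                  # False: distinct but equivalent nodes
print(m.owner.inputs[0] is m.owner.inputs[1])  # True: merged into one instance
```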

Unfortunately, the graphs inside Scan Ops are not actually a part of the graphs they're contained within (i.e. a graph that uses the Scan Op), so they are inaccessible to the FunctionGraphs that undergo rewriting and are not "merged" (or canonicalized for that matter).

Another issue is that comparisons of Scans' inner-graphs need to be performed under alpha equivalence (i.e. variable names/labels shouldn't matter). This is only an issue when two distinct but equivalent Scan Ops are created outside of a FunctionGraph (e.g. by a user). In this case, identity-based comparisons are not relevant, because each Scan's inner-graph inputs are necessarily distinct, owner/parent-less Variables, even after "merging".
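To make the alpha-equivalence problem concrete, here is a toy comparison in plain Python. The `Input`/`Node` classes and `alpha_equal` are hypothetical, and this is not Aesara's equal_computations. It walks both graphs in lockstep (exactly the kind of full traversal this issue wants to avoid) and treats designated inputs as corresponding by position rather than by identity or name.

```python
class Input:
    """Toy owner-less input; only its identity distinguishes it (the name is cosmetic)."""

    def __init__(self, name):
        self.name = name


class Node:
    """Toy operation applied to Input or Node arguments."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs


def alpha_equal(x, y, in_xs, in_ys):
    """Compare graphs `x` and `y`, treating in_xs[i] and in_ys[i] as the same bound variable.

    Inputs not listed in in_xs/in_ys are free and must be the identical object.
    """
    fwd = dict(zip(in_xs, in_ys))
    bwd = dict(zip(in_ys, in_xs))

    def walk(a, b):
        if a in fwd or b in bwd:
            # Bound inputs must correspond positionally, in both directions.
            return fwd.get(a) is b and bwd.get(b) is a
        if isinstance(a, Input) or isinstance(b, Input):
            return a is b
        return (
            a.op == b.op
            and len(a.inputs) == len(b.inputs)
            and all(walk(i, j) for i, j in zip(a.inputs, b.inputs))
        )

    return walk(x, y)


a = Input("a")  # a shared outer variable, free in both bodies
x1, y1 = Input("x_tm1"), Input("y_t")
x2, y2 = Input("state"), Input("elem")  # same roles, different names and objects
body1 = Node("add", x1, Node("mul", y1, a))
body2 = Node("add", x2, Node("mul", y2, a))

print(alpha_equal(body1, body2, [], []))              # False: identity alone fails
print(alpha_equal(body1, body2, [x1, y1], [x2, y2]))  # True: equal up to renaming
```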

~Instead, we should use something like the original inner-graph function (i.e. the function object provided by the user that generated the inner-graph itself) to compare instances of Scans. This approach wouldn't be as general as the current one, but it would be significantly more performant and it should cover the most relevant equivalence checks (i.e. the ones that occur during the manipulation of a FunctionGraph).~

~The only immediate downside I can think of right now would be less robust caching (e.g. the same function renamed or slightly changed wouldn't match an equivalent cached version), but we should think about this a bit more.~

brandonwillard commented 2 years ago

After a little more thought, I've realized that using the Python function object would complicate rewrites, because each rewrite of a Scan would require a new corresponding Python function object. This could be hacked around (for instance, by relying less on the Python function object itself and labeling mutations with simple accumulated identifiers), but that approach is too indirect and complicated.

I'm starting to think that the best approach is to add Scan's inner-graphs to the FunctionGraph somehow.

This could be accomplished with special handling in FunctionGraph, or by moving Scan's inner-graphs into its inputs. I'm partial to the latter, because it involves less special-casing, so I'll focus on that approach for now.

Doing so would allow Scans to be identified (or "merged") and compared through conventional means (as well as canonicalized), and this would be consistent throughout all the rewriting.

The alpha equivalence issue is separate from everything else and requires additional functionality. For this, I think it might be possible to implement a special "nominal" Variable type that's really just a set of singletons that serve as a proxy for their underlying Types (i.e. they are only differentiated by the Types to which they uniquely correspond).
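A rough sketch of such nominal terms in plain Python, under stated assumptions: `NominalVariable` here is a hypothetical class, Types are stood in for by hashable tuples, and, beyond what the proposal says, each term is keyed by a position index as well as its Type, so that two same-typed inner inputs (like x_tm1 and y_t in the example later in this thread) stay distinct. Because equal keys always yield the same object, inner-graphs built independently from these singletons are identical term by term, so renaming no longer matters.

```python
class NominalVariable:
    """Hypothetical interned placeholder: exactly one instance per (index, type) key."""

    _registry = {}

    def __new__(cls, index, type_):
        key = (index, type_)
        inst = cls._registry.get(key)
        if inst is None:
            inst = super().__new__(cls)
            inst.index, inst.type = index, type_
            cls._registry[key] = inst
        return inst

    def __repr__(self):
        return f"NominalVariable({self.index}, {self.type!r})"


# Stand-in for a scalar float64 tensor Type (hypothetical representation).
scalar_f64 = ("float64", ())


def make_body():
    """Build a toy inner-graph (nested tuples) for x_t = x_tm1 + y_t * a."""
    x_tm1 = NominalVariable(0, scalar_f64)
    y_t = NominalVariable(1, scalar_f64)
    a = NominalVariable(2, scalar_f64)
    return ("add", x_tm1, ("mul", y_t, a))


print(NominalVariable(0, scalar_f64) is NominalVariable(0, scalar_f64))  # True
print(make_body() == make_body())  # True: both bodies are built from the same singletons
```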

Next, we need to consider the ramifications of moving inner-graphs to Scan's inputs.

The first thing to consider: an inner-graph introduces new "scalar" (i.e. per-step) inputs for each of the sequences it iterates over and for each of the lags/leads (taps) it uses. These new inputs would show up as inputs to the entire graph, and we don't want that.

(NB: These inner-graph inputs are also the same Variables that would need to be treated as nominal terms.)

brandonwillard commented 2 years ago

Here's a simple example illustrating the effect of naively adding the inner graph to Scan's inputs:

```python
import aesara
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

a = at.scalar("a")
y = at.vector("y")

def inner_graph_fn(y_t, x_tm1, a):
    x_t = x_tm1 + y_t * a

    x_tm1.name = "x_tm1"
    y_t.name = "y_t"
    x_t.name = "x_t"

    return x_t

z, _ = aesara.scan(
    inner_graph_fn,
    sequences=[y],
    outputs_info={"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
    non_sequences=[a],
)

aesara.dprint(z, depth=3)
# Subtensor{int64::} [id A] ''
#  |for{cpu,scan_fn} [id B] ''
#  | |Subtensor{int64} [id C] ''
#  | |Subtensor{:int64:} [id D] ''
#  | |IncSubtensor{Set;:int64:} [id E] ''
#  | |a [id F]
#  |ScalarConstant{1} [id G]
#
# Inner graphs:
#
# for{cpu,scan_fn} [id B] ''
#  >Elemwise{add,no_inplace} [id H] 'x_t'
#  > |x_tm1 [id I] -> [id E]
#  > |Elemwise{mul,no_inplace} [id J] ''
#  >   |y_t [id K] -> [id D]
#  >   |a_copy [id L] -> [id F]
```

The nodes with labels C, D, E, and F are all "outer-inputs", i.e. the actual inputs to the Scan node B (the values that need to be provided in order to compute B).

While the inner-graph is printed, it's otherwise completely ignored by everything in Aesara except the Scan Op itself. This graph has its own inputs, with labels I, K, and L, and the arrows indicate the outer-inputs to which they correspond. (N.B. the outer-input with label C has no counterpart in the inner-graph because it's the number of steps.)
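For readers less familiar with Scan, the following plain-Python loop is a rough reference for what the example computes, with comments tying each value to the dprint labels above. It is a simplification: the real Scan works on preallocated buffers (e.g. the IncSubtensor labeled E), which aren't modeled here, and `scan_reference` is a hypothetical helper.

```python
def scan_reference(y, a, x0=0.0):
    """Plain-Python stand-in for the example's Scan: x_t = x_tm1 + y_t * a.

    Outer-inputs (dprint labels): C ~ the number of steps, D ~ the sequence y,
    E ~ the buffer seeded with the initial value x0, F ~ the non-sequence a.
    """
    n_steps = len(y)  # C
    x_tm1 = x0  # initial state read from the buffer E; inner-input I
    xs = []
    for t in range(n_steps):
        y_t = y[t]  # inner-input K, a per-step slice of outer-input D
        x_t = x_tm1 + y_t * a  # `a` plays the role of inner-input L (a_copy)
        xs.append(x_t)
        x_tm1 = x_t  # generated recursively by the inner-graph itself
    return xs


print(scan_reference([1.0, 2.0, 3.0], a=2.0))  # [2.0, 6.0, 12.0]
```

Assuming the usual Scan semantics, this should match what z evaluates to for the same inputs, since z (the Subtensor{int64::} with ScalarConstant{1}) drops the initial-value slot of the output buffer.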

If we were to add the inner-graph to the "outer-graph" (i.e. the graph in which the Scan itself is contained), its inputs would be added to the outer-graph; however, those are not real inputs that we expect a user to provide. Their values come from the outer-inputs and/or are generated recursively by the inner-graph itself.

Here it is illustrated:

```python
scan_var = z.owner.inputs[0]

# Construct a graph using just the `Scan` node's inputs
scan_fg = FunctionGraph(outputs=scan_var.owner.inputs)

# These are the only input dependencies
scan_fg.inputs
# [a, y]

# If we add the `Scan`'s inner graph to the `FunctionGraph`, we see new,
# irrelevant inputs corresponding to the inner-graph inputs
scan_inner_graph_outputs = scan_var.owner.op.outputs
new_scan_fg = FunctionGraph(outputs=scan_inner_graph_outputs + scan_var.owner.inputs)

new_scan_fg.inputs
# [a, y, x_tm1, y_t, a_copy]
```

This output, and the one before, show that the inner-graph constructs its own copies of "non-sequence" variables (e.g. a_copy for a). This kind of duplication seems especially counterproductive for an in-graph representation of the loop body.

brandonwillard commented 2 years ago

My first thought is that we could make the proposed nominal term type, NominalVariable, also operate like a Constant and never be considered as a FunctionGraph input. Since each inner-graph input would need to be a NominalVariable anyway, this would prevent FunctionGraph from adding them as inputs to itself.
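A toy version of this idea in plain Python: input collection skips Constants and (hypothetical) NominalVariables alike, so even when the inner-graph is placed among the Scan node's inputs, x_tm1, y_t, and a_copy are not reported as graph inputs. All classes and `graph_inputs` here are illustrative stand-ins, not Aesara's FunctionGraph machinery.

```python
class Variable:
    """Toy variable; `owner` is the Apply node that produced it, or None."""

    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner


class Constant(Variable):
    """Owner-less, but never a graph input because its value is fixed."""


class NominalVariable(Variable):
    """Hypothetical inner-graph placeholder, excluded from inputs like a Constant."""


class Apply:
    """Toy node applying `op` to `inputs`, with a single output Variable."""

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = list(inputs)
        self.output = Variable(f"{op}_out", owner=self)


def graph_inputs(outputs):
    """Return the owner-less, non-Constant, non-nominal Variables reachable from `outputs`."""
    seen, found, stack = set(), [], list(outputs)
    while stack:
        var = stack.pop()
        if var in seen:
            continue
        seen.add(var)
        if var.owner is not None:
            stack.extend(var.owner.inputs)
        elif not isinstance(var, (Constant, NominalVariable)):
            found.append(var)
    return found


# Outer-inputs from the example, plus nominal inner-inputs for the loop body.
a, y = Variable("a"), Variable("y")
x_tm1, y_t, a_copy = (NominalVariable(n) for n in ("x_tm1", "y_t", "a_copy"))
inner_out = Apply("add", [x_tm1, Apply("mul", [y_t, a_copy]).output]).output

# The inner-graph output is moved into the Scan node's inputs.
scan_out = Apply("scan", [Constant("0.0"), y, a, inner_out]).output

print(sorted(v.name for v in graph_inputs([scan_out])))  # ['a', 'y']
```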

rlouf commented 2 years ago

Back in the days when I was in the process of re-inventing theano (!) I spent a lot of time thinking about how to represent control flow. I came to the conclusion that the graph corresponding to the inner function should be an input of the operator (so should the graph corresponding to the condition function in a while loop) and use placeholder nodes for the function parameters.
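A small plain-Python sketch of that representation, assuming a hypothetical `loop` operator whose first input is the body graph and whose body parameters are placeholder nodes. None of these names come from Aesara; the point is only that the body becomes an ordinary input that an evaluator (or a rewriter) can see.

```python
import operator


class Placeholder:
    """Stands for a loop-body parameter; bound to a concrete value at each step."""

    def __init__(self, name):
        self.name = name


class Op:
    """Toy node applying `fn` to its evaluated inputs."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs


class Body:
    """A graph packaged with its parameter placeholders, passed to `loop` as an input."""

    def __init__(self, params, output):
        self.params = params
        self.output = output


def evaluate(node, env):
    """Evaluate a toy graph; `env` maps Placeholders to values."""
    if isinstance(node, Placeholder):
        return env[node]
    if isinstance(node, Op):
        return node.fn(*(evaluate(i, env) for i in node.inputs))
    return node  # literals (and Body objects) evaluate to themselves


def loop(body, init, sequence, extra):
    """Hypothetical loop operator; body.params are (state, element, extra)."""
    state, results = init, []
    for elem in sequence:
        state = evaluate(body.output, dict(zip(body.params, (state, elem, extra))))
        results.append(state)
    return results


x_tm1, y_t, a = Placeholder("x_tm1"), Placeholder("y_t"), Placeholder("a")
body = Body((x_tm1, y_t, a), Op(operator.add, x_tm1, Op(operator.mul, y_t, a)))

# The body is just another input of the loop node.
graph = Op(loop, body, 0.0, [1.0, 2.0, 3.0], 2.0)
print(evaluate(graph, {}))  # [2.0, 6.0, 12.0]
```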

I wonder if we should start working on a larger design document for the Scan operator that also includes the changes in the user interface?

brandonwillard commented 2 years ago

> Back in the days when I was in the process of re-inventing theano (!) I spent a lot of time thinking about how to represent control flow. I came to the conclusion that the graph corresponding to the inner function should be an input of the operator (so should the graph corresponding to the condition function in a while loop) and use placeholder nodes for the function parameters.

Yes, it's a very natural approach. The actual approach taken by Scan was probably just a seemingly convenient means of avoiding interactions and/or changes that might've involved the broader Theano graph framework.

> I wonder if we should start working on a larger design document for the Scan operator that also includes the changes in the user interface?

We have the beginnings of one here that can be updated.

kc611 commented 2 years ago

> This could be accomplished with special handling in FunctionGraph, or by moving Scan's inner-graphs into its inputs. I'm partial to the latter, because it involves less special-casing, so I'll focus on that approach for now.

Can the former functionality you're proposing here be generalized to include things like https://github.com/aesara-devs/aesara/issues/749?

Since we're already considering adding nodes that aren't part of the main graph to FunctionGraph, we might as well have a common framework for handling such 'supporting' graphs (e.g. giving FunctionGraph the capability to have 'hidden' inputs and 'hidden' outputs).

brandonwillard commented 2 years ago

> Can the former functionality you're proposing here be generalized to include things like #749?

That depends on exactly how we implement this (e.g. whether or not it involves relevant changes to the way FunctionGraph works). I don't think the key changes necessarily need to happen in the FunctionGraph, though. If anything, the most important changes will be in the FunctionGraph compilation process (i.e. the steps following rewrites in a call to aesara.function).

I say this because that's the only place where we would need to ignore parts of a graph (because they're irrelevant to its computation). In other words, we don't want to have to change all the graph walking and rewriting code so that it's made aware of hidden sub-graphs, especially when we can simply ignore irrelevant sub-graphs in one part of the compilation process.

But, yes, they are both very related; I'm just trying to work out how best to do this without too much refactoring and/or adding too much conceptual overhead to Aesara's inner workings.

brandonwillard commented 2 years ago

Actually, a better way to put this: not all sub-graphs should necessarily be compiled (i.e. have thunks created for them).

We could do this by altering the results of FunctionGraph.toposort, so that it doesn't include "hidden" sub-graphs, but this method is used pretty generally throughout the library for rewrite and compilation purposes. There are alternate interfaces for getting the topologically sorted nodes that go by the names "schedule" and "orderings" (e.g. FunctionGraph.orderings, Feature.orderings, Linker.schedule), and they tend to be more computation-specific. Changing one (or both) of those so that it excludes some sub-graphs might be the most appropriate approach, since it wouldn't affect the rewrite process.

One thing I'm trying to work out is where/how the exclusions are specified. Right now, I'm thinking it should be provided by the Apply node in some way, but it could just as well be specified by the Op.
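A toy sketch of that split, assuming (one of the options raised here) that each node declares which of its input positions hold sub-graphs: the rewrite-facing sort still visits hidden sub-graphs, while the compute schedule skips them. `Node`, `hidden`, and `toposort` are hypothetical names, not Aesara's FunctionGraph.toposort or Linker.schedule.

```python
class Var:
    """Toy variable; `owner` is the Node that produced it (None for graph inputs)."""

    def __init__(self, owner=None):
        self.owner = owner


class Node:
    """Toy Apply-like node; `hidden` holds indices of inputs that are sub-graphs, not data."""

    def __init__(self, name, inputs, hidden=()):
        self.name = name
        self.inputs = list(inputs)
        self.hidden = frozenset(hidden)
        self.out = Var(owner=self)


def toposort(outputs, include_hidden=True):
    """Return nodes in dependency order; optionally skip inputs marked as hidden."""
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for i, var in enumerate(node.inputs):
            if var.owner is not None and (include_hidden or i not in node.hidden):
                visit(var.owner)
        order.append(node)  # appended only after all of its dependencies

    for var in outputs:
        if var.owner is not None:
            visit(var.owner)
    return order


x = Var()
inner_mul = Node("mul", [x, x])
inner_add = Node("add", [inner_mul.out, x])
scan = Node("scan", [x, inner_add.out], hidden=[1])  # input 1 is the inner-graph
total = Node("sum", [scan.out])

rewrite_order = [n.name for n in toposort([total.out])]
compute_order = [n.name for n in toposort([total.out], include_hidden=False)]
print(rewrite_order)  # ['mul', 'add', 'scan', 'sum']
print(compute_order)  # ['scan', 'sum']
```

Keeping the exclusion in a separate, computation-oriented ordering leaves graph walking and rewriting untouched, which is the motivation given above for not changing toposort itself.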

ricardoV94 commented 2 years ago

@brandonwillard is this one done?

brandonwillard commented 2 years ago

> @brandonwillard is this one done?

No, but many of the core changes are in. The next steps are in #824, which needs to be updated and finished.