inducer / pytato

Lazily evaluated arrays in Python

Add a pass to flag arrays only differing in tags #420

Open inducer opened 1 year ago

inducer commented 1 year ago

@majosm reported a situation in which a large difference in compile time was observed depending on whether an array had a tag or not. This is plausible, as even just different tags can lead to arrays not being viewed as equal and therefore failing to be merged in common subexpression elimination. This means that the value (and all its dependents, if both versions are used) is computed multiple times. If the pattern occurs multiple times, this could lead to exponential growth of the DAG size.
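A toy model may help make the mechanism concrete. The sketch below is plain Python, not pytato: the hypothetical `Node` class stands in for an array expression, and the hypothetical `count_unique_nodes` helper plays the role of common subexpression elimination by merging nodes that compare equal. Under that assumption, two nodes that differ only in their tags, and everything built on top of them, survive as separate copies.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy expression node (hypothetical; not pytato's Array class)."""
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def count_unique_nodes(*roots):
    """Count nodes after merging all subexpressions that compare equal (toy CSE)."""
    seen = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.children)
    return len(seen)


x, y = Node("x"), Node("y")
tmp = Node("add", (x, y))
tmp_tagged = Node("add", (x, y), frozenset({"ImplStored"}))

# Same structure, different tags: the two nodes compare unequal, so CSE keeps both.
assert tmp != tmp_tagged

# Every dependent inherits the difference and is duplicated as well.
dep = Node("exp", (tmp,))
dep_tagged = Node("exp", (tmp_tagged,))

print(count_unique_nodes(dep, dep_tagged))           # -> 6 (x, y, 2x add, 2x exp)
print(count_unique_nodes(dep, Node("exp", (tmp,))))  # -> 4 (x, y, add, exp)
```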

All of this is likely unintended in almost all cases, so we should at least warn about it (if not raise an error). What I have in mind is a pass that strips all tags and flags the situation in which that process produces multiple versions of the same array that compare equal after stripping.
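A minimal sketch of the proposed pass, again on a hypothetical toy `Node` class rather than pytato's arrays: strip every tag recursively, group the original nodes by their stripped form, and report any group with more than one distinct member. `strip_tags` and `find_tag_only_duplicates` are illustrative names, not existing pytato functions.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy expression node (hypothetical; not pytato's Array class)."""
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def strip_tags(node, memo):
    """Return a copy of *node* with all tags removed, recursively."""
    if node not in memo:
        memo[node] = Node(node.op,
                          tuple(strip_tags(c, memo) for c in node.children))
    return memo[node]


def find_tag_only_duplicates(*roots):
    """Return groups of distinct nodes that compare equal once tags are stripped."""
    memo = {}
    groups = defaultdict(set)
    seen = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.children)
        groups[strip_tags(node, memo)].add(node)
    # A group with more than one member means tags alone kept nodes apart.
    return [group for group in groups.values() if len(group) > 1]


# Toy version of the thread's example: out = 2*tmp + 3*tmp1, tmp1 = tmp tagged.
x, y = Node("x"), Node("y")
tmp = Node("add", (x, y))
tmp1 = Node("add", (x, y), frozenset({"ImplStored"}))
out = Node("add", (Node("mul_2", (tmp,)), Node("mul_3", (tmp1,))))

dups = find_tag_only_duplicates(out)
print(len(dups))  # -> 1 (the group {tmp, tmp1})
```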

majosm commented 1 year ago

> What I have in mind is a pass that strips all tags and flags the situation in which that process produces multiple versions of the same array that compare equal after stripping.

Where might be a good place to insert this pass? (I'm not very familiar with the overall structure of pytato yet.)

inducer commented 1 year ago

I think it would come down to adding a function to pytato.analysis that perhaps uses a custom WalkMapper for the traversal. That function could then be called from an appropriate place in the array context to actually perform the check.
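To illustrate the shape of such an analysis function, here is a toy post-order walker with a visited-node cache and a `post_visit` hook. It only loosely mimics the idea behind pytato's walk mappers; `Node`, `ToyCachedWalker`, `TaggedNodeCollector`, and `collect_tagged_nodes` are hypothetical names, and the real mapper interface may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy expression node (hypothetical; not pytato's Array class)."""
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


class ToyCachedWalker:
    """Post-order DAG walk that visits each node object once (hypothetical)."""

    def __init__(self):
        self._visited = set()

    def get_cache_key(self, node):
        # Identity-based key: a shared node object is visited only once.
        return id(node)

    def __call__(self, node):
        key = self.get_cache_key(node)
        if key in self._visited:
            return
        self._visited.add(key)
        for child in node.children:
            self(child)
        self.post_visit(node)  # all children have been visited by now

    def post_visit(self, node):
        """Hook for subclasses; called once per node, after its children."""


class TaggedNodeCollector(ToyCachedWalker):
    """Example analysis: collect every node that carries at least one tag."""

    def __init__(self):
        super().__init__()
        self.tagged = []

    def post_visit(self, node):
        if node.tags:
            self.tagged.append(node)


def collect_tagged_nodes(expr):
    """The kind of entry point an analysis module might expose."""
    collector = TaggedNodeCollector()
    collector(expr)
    return collector.tagged


x, y = Node("x"), Node("y")
tmp = Node("add", (x, y))
tmp1 = Node("add", (x, y), frozenset({"ImplStored"}))
out = Node("add", (Node("mul_2", (tmp,)), Node("mul_3", (tmp1,))))

print(len(collect_tagged_nodes(out)))  # -> 1 (only tmp1 carries a tag)
```

Keying the cache on `id(node)` means that a shared node is visited once, while equal-but-distinct objects are each visited separately.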

kaushikcfd commented 1 year ago

Here's one way to do it:

```
(py311_env) $ cat remove_tags_and_merge.py
import pytato as pt
import numpy as np
from pytools.tag import Tag


def remove_tag_t(expr, tag_t):
    """Remove all tags of type *tag_t* from every array in *expr*, then
    merge the branches that have become equal."""
    def _rec_remove_tag_t(expr):
        # Applied by map_and_copy to each node of the expression DAG.
        if isinstance(expr, pt.Array):
            if tags_to_remove := expr.tags_of_type(tag_t):
                return expr.without_tags(tags_to_remove,
                                         verify_existence=False)
        # Untagged arrays and non-array nodes are returned unchanged.
        return expr

    expr = pt.transform.map_and_copy(expr, _rec_remove_tag_t)
    return pt.transform.BranchMorpher()(expr)


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)

tmp = x + y
# Same expression as tmp, differing only in an ImplStored tag.
tmp1 = tmp.tagged(pt.tags.ImplStored())

out = 2*tmp + 3*tmp1

# Node count before and after stripping (tag_t=Tag matches every array tag).
print(pt.analysis.get_num_nodes(out))
print(pt.analysis.get_num_nodes(remove_tag_t(out, tag_t=Tag)))
(py311_env) $ python remove_tags_and_merge.py
8
7
```

kaushikcfd commented 1 year ago

> This is plausible, as even just different tags can lead to arrays not being viewed as equal and therefore failing to be merged in common subexpression elimination

This is true, but if only one node differs in its tags, then something else must be wrong here: the subexpressions feeding the diverging nodes would still be the same, so the relative difference in runtime/compile time should have been insignificant.

inducer commented 1 year ago

> This is true, but if only one node differs in its tags, then something else must be wrong here: the subexpressions feeding the diverging nodes would still be the same, so the relative difference in runtime/compile time should have been insignificant.

Are you sure? Wouldn't dependent nodes necessarily also compare non-equal?
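The point about successors can be checked on a toy model (hypothetical `Node` class and `build_chains` helper, not pytato): the tagged and untagged variants share their inputs, so nothing upstream is duplicated, but every node built on top of either variant compares unequal to its counterpart.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy expression node (hypothetical; not pytato's Array class)."""
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def build_chains(depth):
    """Apply the same chain of `depth` ops on top of an untagged and a tagged node."""
    x, y = Node("x"), Node("y")
    tmp = Node("add", (x, y))
    tmp_tagged = Node("add", (x, y), frozenset({"ImplStored"}))
    chain, chain_tagged = tmp, tmp_tagged
    for _ in range(depth):
        chain = Node("exp", (chain,))
        chain_tagged = Node("exp", (chain_tagged,))
    return tmp, tmp_tagged, chain, chain_tagged


tmp, tmp_tagged, top, top_tagged = build_chains(depth=3)

# Predecessors: both variants are built from the very same inputs.
print(tmp.children == tmp_tagged.children)  # -> True

# Successors: the tag difference propagates all the way up the chain.
print(top == top_tagged)  # -> False
```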

inducer commented 1 year ago

> Here's one way to do it:

Thanks for providing that! It's quick, but it has a few downsides: it performs quite a few traversals, and it doesn't explicitly identify the offending nodes.

kaushikcfd commented 1 year ago

> Are you sure? Wouldn't dependent nodes necessarily also compare non-equal?

Aah fair. I was only thinking of the predecessors and not the successors. Thanks for the correction!

> Thanks for providing that! It's quick, but it has a few downsides: it performs quite a few traversals, and it doesn't explicitly identify the offending nodes.

Yep, it's a starting point. However, extending it with the functionality you point out shouldn't take more than another 50 lines, I think :).

kaushikcfd commented 1 year ago

FWIW, this is more in line with what you suggested:

```python
import pytato as pt
import numpy as np
from typing import Dict


class MyWalkMapper(pt.transform.CachedWalkMapper):
    """Raise an error if two distinct arrays become equal once their own
    tags and axis metadata are reset to the defaults."""

    def __init__(self):
        super().__init__()
        # Maps each metadata-free copy to the first array that produced it.
        self.stripped_ary_to_ary: Dict[pt.Array, pt.Array] = {}

    def get_cache_key(self, expr):
        # Visited-node cache keyed on object identity.
        return id(expr)

    def post_visit(self, expr: pt.transform.ArrayOrNames):
        if isinstance(expr, pt.Array):
            from pytato.array import (_get_default_tags,
                                      _get_default_axes)
            # Copy of expr with default tags/axes; its children are unchanged.
            tagless_expr = expr.copy(
                tags=_get_default_tags(),
                axes=_get_default_axes(expr.ndim))
            # Parenthesized lookup: the original walrus expression bound the
            # result of the "!=" comparison (a bool), not the colliding array.
            colliding_expr = self.stripped_ary_to_ary.setdefault(
                tagless_expr, expr)
            if colliding_expr != expr:
                raise ValueError(f"Arrays '{colliding_expr}' and '{expr}'"
                                 " are semantically the same array except"
                                 " for the attached metadata => will lead to"
                                 " inefficient generated code.")


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)

tmp = x + y
tmp1 = tmp.tagged(pt.tags.ImplStored())

out = 2*tmp + 3*tmp1

# Expected to raise ValueError: tmp and tmp1 differ only in the ImplStored tag.
MyWalkMapper()(out)
```
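One design detail in the mapper above: as far as the snippet shows, `expr.copy(tags=..., axes=...)` resets only the node's own metadata, while its children keep theirs. The toy sketch below (hypothetical `Node` and `find_collisions`, not pytato's API; it also collects all collisions instead of raising on the first) compares that one-level stripping with fully recursive stripping. The former reports only the node where the tags diverge, whereas the latter also reports each of its dependents.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy expression node (hypothetical; not pytato's Array class)."""
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def find_collisions(root, strip_children):
    """Return (first_seen, colliding) pairs of nodes that differ only in tags.

    If strip_children is False, only each node's own tags are dropped and
    its children are kept as-is, similar to the expr.copy(...) call above.
    """
    def strip(node):
        children = (tuple(strip(c) for c in node.children)
                    if strip_children else node.children)
        return Node(node.op, children)

    stripped_to_node = {}
    collisions = []
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.extend(node.children)
        first_seen = stripped_to_node.setdefault(strip(node), node)
        if first_seen != node:
            collisions.append((first_seen, node))
    return collisions


x, y = Node("x"), Node("y")
tmp = Node("add", (x, y))
tmp1 = Node("add", (x, y), frozenset({"ImplStored"}))
out = Node("add", (Node("exp", (tmp,)), Node("exp", (tmp1,))))

print(len(find_collisions(out, strip_children=False)))  # -> 1 (tmp vs tmp1)
print(len(find_collisions(out, strip_children=True)))   # -> 2 (plus exp(tmp) vs exp(tmp1))
```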