inducer / pymbolic

A simple package to do symbolic math (focus on code gen and DSLs)
http://mathema.tician.de/software/pymbolic

Applicability for Image Domain Delayed Operation Problem #45

Status: open. Erotemic opened this issue 3 years ago.

Erotemic commented 3 years ago

I'm currently working on code to perform "delayed operations" on images. After working on it for a while, I'm realizing that what I'm writing is really an expression tree, and what I'm trying to do is simplify that tree before executing it.

In simplified form, the available operations are Load, Crop, Cat (channel concatenation), and Warp (an affine transform).

An example tree might look like:

    +- Crop
       |
       +- Cat
          |
          +- Warp
          |  |
          |  +- Cat
          |     |
          |     +- Crop
          |     |  |
          |     |  +- Load
          |     |
          |     +- Warp
          |        |
          |        +- Load
          |
          +- Warp
             |
             +- Load

or, as an expression, C(A(W(A(C(L), W(L))), W(L))), where C is Crop, A is Cat, W is Warp, and L is Load.

However, if I were to simply execute this graph, it would be inefficient: the crop at the root is likely to throw away much of the intermediate computation. So in practice it might be better to rewrite this tree as:

    +- Cat
       |
       +- Warp
       |  |
       |  +- Crop
       |     |
       |     +- Load
       |
       +- Warp
       |  |
       |  +- Crop
       |     |
       |     +- Load
       |
       +- Warp
          |
          +- Crop
             |
             +- Load

or, as an expression, A(W(C(L)), W(C(L)), W(C(L))). (Note that I'm abusing notation: each operation may have different parameters.)

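Here, all of the cat operations have been moved to the root, all neighboring warp operations have been squashed together, and the crops have all been moved toward the leaves. Cats can be lifted because cat is associative, A(a, A(b, c)) = A(A(a, b), c), and because a warp applied to a concatenation distributes over its components, W(A(a, b)) = A(W(a), W(b)). Neighboring warps can be squashed because they are all affine, so we just multiply their matrices. Crops can be pushed down because a crop likewise distributes over a concatenation, C(A(a, b)) = A(C(a), C(b)), and because for an expression C(W(L)) we can compute a new crop C' and warp W' such that C(W(L)) = W'(C'(L)).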
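The crop/warp swap C(W(L)) = W'(C'(L)) can be worked out with 3x3 homogeneous affine matrices. The sketch below relies on assumptions that are not in the original post: a warp matrix maps input pixel coordinates to output coordinates, and a crop is an axis-aligned box (x0, y0, x1, y1) in the coordinates of the image it is applied to. Under those assumptions, C' is the bounding box of the crop's corners mapped back through the inverse warp, and W' = T(-x0, -y0) · W · T(bx0, by0), where T is a translation and (bx0, by0) is the top-left corner of C'. All helper names are hypothetical. Real code would also round C' outward to whole pixels and pad it for the interpolation kernel, which this sketch does not do.

```python
from typing import List, Tuple

Matrix = List[List[float]]               # 3x3 homogeneous affine matrix
Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1)


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def translate(tx: float, ty: float) -> Matrix:
    return [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]]


def invert_affine(m: Matrix) -> Matrix:
    """Invert an affine matrix whose last row is [0, 0, 1]."""
    (a, b, tx), (c, d, ty) = m[0], m[1]
    det = a * d - b * c
    ia, ib, ic, idd = d / det, -b / det, -c / det, a / det
    return [[ia, ib, -(ia * tx + ib * ty)],
            [ic, idd, -(ic * tx + idd * ty)],
            [0.0, 0.0, 1.0]]


def apply(m: Matrix, x: float, y: float) -> Tuple[float, float]:
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


def push_crop_through_warp(crop: Box, warp: Matrix) -> Tuple[Box, Matrix]:
    """Turn C(W(img)) into W'(C'(img)); returns (C' box, W' matrix)."""
    x0, y0, x1, y1 = crop
    inv = invert_affine(warp)
    # Map the crop's corners back into the input image; take their bounding box.
    corners = [apply(inv, x, y) for x in (x0, x1) for y in (y0, y1)]
    xs, ys = [p[0] for p in corners], [p[1] for p in corners]
    new_crop = (min(xs), min(ys), max(xs), max(ys))
    # W' = T(-x0, -y0) @ W @ T(bx0, by0): undo the input crop's offset,
    # warp, then shift into the output crop's coordinate frame.
    new_warp = matmul(translate(-x0, -y0),
                      matmul(warp, translate(new_crop[0], new_crop[1])))
    return new_crop, new_warp


scale2 = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
new_crop, new_warp = push_crop_through_warp((10, 10, 30, 30), scale2)
print(new_crop)            # (5.0, 5.0, 15.0, 15.0)
print(new_warp == scale2)  # True
```

For a pure 2x scale, the output crop [10, 30) needs only [5, 15) of the input, and the new warp is the same scale. For a rotation, the bounding box covers more than the crop strictly needs. The result is still correct, but a few extra input pixels are read.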

I may not have explained this exactly right, but the intuition is that all cats can be delayed until the end, all crops should happen immediately after a load, and all neighboring warps can be combined into a single operation.

In this new simplified expression tree all the crops occur immediately after a load operation, so you are only working with pixels that will influence the final result.
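The target shape described above can also be stated as a property that can be checked. The following minimal sketch uses a hypothetical nested-tuple encoding (op letter first, then children, parameters omitted) instead of the pymbolic classes. It checks three conditions: cats appear only at the root, every crop sits directly on a load, and no warp sits directly on another warp.

```python
LOAD = ("L",)


def is_normal_form(node: tuple, is_root: bool = True) -> bool:
    """Check the simplified shape: cats only at the root, crops only directly
    on loads, and no warp applied directly to another warp."""
    op, *children = node
    if op == "A" and not is_root:
        return False  # cats should all be lifted to the root
    if op == "C" and children[0][0] != "L":
        return False  # crops should happen right after a load
    if op == "W" and children[0][0] == "W":
        return False  # neighboring warps should be fused
    return all(is_normal_form(child, is_root=False) for child in children)


original = ("C", ("A", ("W", ("A", ("C", LOAD), ("W", LOAD))), ("W", LOAD)))
simplified = ("A", ("W", ("C", LOAD)), ("W", ("C", LOAD)), ("W", ("C", LOAD)))
print(is_normal_form(original))    # False
print(is_normal_form(simplified))  # True
```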

I'm currently slogging through the logic to implement this, but I've reached the point where I'm just going to have to hack it, because I don't have a strong grasp of how to implement it properly. I searched a bit for resources or packages that would let me define how the parameters of operations change when you swap their order in the tree, and that would ideally provide a simplify operation that computes the final structure I'm actually interested in for the general case.

I was wondering whether this package could be used for something like that.

inducer commented 3 years ago

Sure, I think that could work. It's easy to define your own node types and to write term-rewriting traversals that do the operator squashing ("fusion"); see the mini-tutorial in the docs. There's also arithmetic between nodes if you need it.

Erotemic commented 3 years ago

Thanks for the pointer. I played around with it a little bit, and it seems to work roughly the way I would expect. So far I did this:

"""
Testing:
    https://documen.tician.de/pymbolic/
"""
import pymbolic as pmbl
from pymbolic.mapper import IdentityMapper
from pymbolic.primitives import Expression
import numpy as np

class AutoInspectable(object):
    """
    Helper to provide automatic defaults for pymbolic expressions
    """

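    @property
    def init_arg_names(self):
        # pymbolic's own nodes define init_arg_names as a tuple attribute,
        # so expose a property rather than a method; reading the names from
        # the signature also works before __init__ has run (e.g. unpickling)
        import inspect
        return tuple(inspect.signature(self.__class__).parameters)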

    def __getinitargs__(self):
        return tuple(self._initkw().values())

    def _initkw(self):
        import inspect
        from collections import OrderedDict
        sig = inspect.signature(self.__class__)
        initkw = OrderedDict()
        for name, info in sig.parameters.items():
            if not hasattr(self, name):
                raise NotImplementedError((
                    'Unable to introspect init args because the class '
                    'did not have attributes with the same names as the '
                    'constructor arguments'))
            initkw[name] = getattr(self, name)
        return initkw

class AutoExpression(AutoInspectable, Expression):
    pass

class Warp(AutoExpression):
    def __init__(self, sub_data, transform):
        self.sub_data = sub_data
        self.transform = transform

    mapper_method = "map_warp"

class ChanCat(AutoExpression):
    def __init__(self, components):
        self.components = components

    mapper_method = "map_chancat"

class RawImage(AutoExpression):
    def __init__(self, data):
        self.data = data

    mapper_method = "map_raw"

class WarpFusionMapper(IdentityMapper):
    def map_warp(self, expr):
        if isinstance(expr.sub_data, Warp):
            # Fuse neighboring warps
            t1 = expr.transform
            t2 = expr.sub_data.transform
            new_tf = t1 @ t2
            new_subdata = self.rec(expr.sub_data.sub_data)
            new = Warp(new_subdata, new_tf)
            return new
        elif isinstance(expr.sub_data, ChanCat):
            # A warp applied to a ChanCat becomes a ChanCat of warped
            # components (the warp distributes over the concatenation)
            tf = expr.transform
            new_components = []
            for comp in expr.sub_data.components:
                new_comp = Warp(comp, tf)
                new_components.append(new_comp)
            new = ChanCat(new_components)
            new = self.rec(new)
            return new
        else:
            return expr

    def map_chancat(self, expr):
        # ChanCat is associative
        new_components = []

        def _flatten(comps):
            for c in comps:
                if isinstance(c, ChanCat):
                    yield from _flatten(c.components)
                else:
                    yield c
        new_components = [self.rec(c) for c in _flatten(expr.components)]
        new = ChanCat(new_components)
        return new

    def map_raw(self, expr):
        return expr

class Transform:
    # temporary transform class for easier to read outputs in POC
    def __init__(self, f):
        self.f = f

    def __matmul__(self, other):
        return Transform(self.f * other.f)

    def __str__(self):
        return 'Tranform({})'.format(self.f)

    def __repr__(self):
        return 'Tranform({})'.format(self.f)

raw1 = RawImage('image1')
w1_a = Warp(raw1, Transform(2))
w1_b = Warp(w1_a, Transform(3))

raw2 = RawImage('image2')
w2_a = Warp(raw2, Transform(5))
w2_b = Warp(w2_a, Transform(7))

raw3 = RawImage('image3')
w3_a = Warp(raw3, Transform(11))
w3_b = Warp(w3_a, Transform(13))

cat1 = ChanCat([w1_b, w2_b])

warp_cat = Warp(cat1, Transform(17))

cat2 = ChanCat([warp_cat, w3_b])

mapper = WarpFusionMapper()

print('cat2    = {!r}'.format(cat2))
result1 = mapper(cat2)
print('result1 = {!r}'.format(result1))
result2 = mapper(result1)
print('result2 = {!r}'.format(result2))
result3 = mapper(result2)
print('result3 = {!r}'.format(result3))

This gives me:

cat2    = ChanCat([Warp(ChanCat([Warp(Warp(RawImage('image1'), Tranform(2)), Tranform(3)), Warp(Warp(RawImage('image2'), Tranform(5)), Tranform(7))]), Tranform(17)), Warp(Warp(RawImage('image3'), Tranform(11)), Tranform(13))])
result1 = ChanCat([ChanCat([Warp(Warp(RawImage('image1'), Tranform(2)), Tranform(51)), Warp(Warp(RawImage('image2'), Tranform(5)), Tranform(119))]), Warp(RawImage('image3'), Tranform(143))])
result2 = ChanCat([Warp(RawImage('image1'), Tranform(102)), Warp(RawImage('image2'), Tranform(595)), Warp(RawImage('image3'), Tranform(143))])
result3 = ChanCat([Warp(RawImage('image1'), Tranform(102)), Warp(RawImage('image2'), Tranform(595)), Warp(RawImage('image3'), Tranform(143))])

Each call to the mapper seems to perform only some of the possible reductions on the tree: I have to call it twice for it to finish fusing and shuffling everything. Calling it a third time doesn't change anything, which is expected.
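One generic way to reach a full simplification is to keep running a single pass until its output stops changing. That automates the "call it twice, and the third call changes nothing" observation above. In the following sketch, `until_fixpoint` and the toy `fuse_outer_pair` step are hypothetical stand-ins: nested warps are tuples with an integer factor, much like the temporary `Transform` class. The loop relies on `==` being structural equality. Whether that holds for the pymbolic nodes above is an assumption to check. Their equality presumably goes through the constructor arguments, so using a tuple instead of a list for `components` and defining `__eq__`/`__hash__` on `Transform` are likely needed.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def until_fixpoint(step: Callable[[T], T], value: T, max_iter: int = 100) -> T:
    """Apply `step` repeatedly until the result no longer changes."""
    for _ in range(max_iter):
        new_value = step(value)
        if new_value == value:
            return new_value
        value = new_value
    raise RuntimeError(f"no fixed point after {max_iter} iterations")


def fuse_outer_pair(expr):
    """Toy single pass: fuse only the outermost pair of nested warps.

    A warp is ("W", factor, child); anything else is treated as a leaf image.
    """
    if isinstance(expr, tuple) and isinstance(expr[2], tuple):
        _, outer, (_, inner, child) = expr
        return ("W", outer * inner, child)
    return expr


nested = ("W", 2, ("W", 3, ("W", 5, "image1")))
print(fuse_outer_pair(nested))                  # ('W', 6, ('W', 5, 'image1'))
print(until_fixpoint(fuse_outer_pair, nested))  # ('W', 30, 'image1')
```

The `max_iter` guard matters if a rule set can undo its own rewrites, which would otherwise loop forever. The cost of this approach is a full extra traversal per iteration, which is what recursing on rewritten nodes inside the mapper avoids.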

Is there a recommended way to accomplish "full simplification"? Additionally, if you have time to look over my implementation to check whether I'm adhering to best practices, or whether I'm using the library in a way I shouldn't, I'd appreciate it.

inducer commented 3 years ago

> Is there a recommended way to accomplish "full simplification"?

I think this depends on how you write the individual map_... routines. Potentially, the output of self.rec(...) will allow for further simplification, but you'd need to check for that.
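Here is a sketch of what that check can look like. The simplifier handles a node's children first, then applies the local rewrite, and runs itself again only on nodes that a rewrite creates. It uses plain Python dataclasses rather than pymbolic classes so that it runs on its own. The names only mirror the ones in the post, and the integer `factor` stands in for the transform, as the temporary `Transform` class does. In a pymbolic `IdentityMapper`, the recursive `simplify(...)` calls would presumably become `self.rec(...)`, but that translation is an assumption, not something tested here. With these rules, a single call reproduces the fully simplified `result2` from the post.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Raw:
    name: str


@dataclass(frozen=True)
class Warp:
    sub: "Expr"
    factor: int  # stand-in for an affine matrix; composing = multiplying


@dataclass(frozen=True)
class ChanCat:
    components: Tuple["Expr", ...]


Expr = Union[Raw, Warp, ChanCat]


def simplify(expr: Expr) -> Expr:
    """Fully simplify in one call: children first, then local rewrites."""
    if isinstance(expr, Raw):
        return expr
    if isinstance(expr, ChanCat):
        flat = []
        for comp in expr.components:
            comp = simplify(comp)
            # A simplified child may have become a ChanCat: splice it in
            # (associativity), so no ChanCat is left directly inside another.
            if isinstance(comp, ChanCat):
                flat.extend(comp.components)
            else:
                flat.append(comp)
        return ChanCat(tuple(flat))
    # Warp: simplify the child first, then inspect what it turned into.
    sub = simplify(expr.sub)
    if isinstance(sub, Warp):
        # sub.sub is already simplified and is neither a Warp nor a ChanCat,
        # so a single fusion is enough here.
        return Warp(sub.sub, expr.factor * sub.factor)
    if isinstance(sub, ChanCat):
        # Distribute the warp over the components; the new Warp nodes may
        # fuse with warps below them, so simplify the rewritten node again.
        return simplify(ChanCat(tuple(Warp(c, expr.factor) for c in sub.components)))
    return Warp(sub, expr.factor)


w1 = Warp(Warp(Raw("image1"), 2), 3)
w2 = Warp(Warp(Raw("image2"), 5), 7)
w3 = Warp(Warp(Raw("image3"), 11), 13)
cat2 = ChanCat((Warp(ChanCat((w1, w2)), 17), w3))

result = simplify(cat2)
expected = ChanCat((Warp(Raw("image1"), 102),
                    Warp(Raw("image2"), 595),
                    Warp(Raw("image3"), 143)))
print(result == expected)          # True
print(simplify(result) == result)  # True: already at a fixed point
```

The design choice is to normalize children before matching a rule, so each rule only has to look one level down. Only one rewrite builds new structure that is not yet simplified (distributing a warp over a ChanCat), and that is the one place that calls `simplify` again.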