egraphs-good / egglog

egraphs + datalog!
https://egraphs-good.github.io/egglog/
MIT License

User Defined Generics #386

Open saulshanabrook opened 2 months ago

saulshanabrook commented 2 months ago

I am opening up this issue to discuss how to support user-defined generics within egglog.

Motivation

Now that first-class functions are working in the egglog Python bindings and I am starting to use them, it's becoming harder to get by without user-defined generics. For example, here is a somewhat minimal class for a tuple of ints:

class TupleInt(Expr, ruleset=array_api_ruleset):
    EMPTY: ClassVar[TupleInt]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

    @classmethod
    def single(cls, i: Int) -> TupleInt:
        return TupleInt(Int(1), lambda _: i)

    @classmethod
    def range(cls, stop: Int) -> TupleInt:
        return TupleInt(stop, lambda i: i)

    @classmethod
    def from_vec(cls, vec: Vec[Int]) -> TupleInt: ...

    def __add__(self, other: TupleInt) -> TupleInt:
        return TupleInt(
            self.length() + other.length(),
            lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> Int: ...

    @method(preserve=True)
    def __len__(self) -> int:
        return int(self.length())

    @method(preserve=True)
    def __iter__(self) -> Iterator[Int]:
        return iter(self[i] for i in range(len(self)))

    def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...

    def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...

    def contains(self, i: Int) -> Boolean:
        return self.fold_boolean(FALSE, lambda acc, j: acc | (i == j))

@array_api_ruleset.register
def _tuple_int(
    i: Int,
    i2: Int,
    k: i64,
    f: Callable[[Int, Int], Int],
    bool_f: Callable[[Boolean, Int], Boolean],
    idx_fn: Callable[[Int], Int],
    vs: Vec[Int],
    b: Boolean,
):
    return [
        rewrite(TupleInt(i, idx_fn).length()).to(i),
        rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
        # From vec
        rewrite(TupleInt.from_vec(vs).length()).to(Int(vs.length())),
        rewrite(TupleInt.from_vec(vs)[Int(k)]).to(vs[k]),
        # fold
        rewrite(TupleInt(0, idx_fn).fold(i, f)).to(i),
        rewrite(TupleInt(i, idx_fn).fold(i2, f)).to(
            f(TupleInt(i - 1, lambda j: idx_fn(j + 1)).fold(i2, f), idx_fn(Int(0))),
            eq(i == Int(0)).to(FALSE),
        ),
        # fold boolean
        rewrite(TupleInt(0, idx_fn).fold_boolean(b, bool_f)).to(b),
        rewrite(TupleInt(i, idx_fn).fold_boolean(b, bool_f)).to(
            bool_f(TupleInt(i - 1, lambda i: idx_fn(i + 1)).fold_boolean(b, bool_f), idx_fn(Int(0))),
            eq(i == Int(0)).to(FALSE),
        ),
        # Empty
        rewrite(TupleInt.EMPTY).to(TupleInt(0, bottom_indexing)),
    ]

@function
def bottom_indexing(i: Int) -> Int: ...

This code has to be repeated for every other kind of tuple (tuple of booleans, tuple of ndarrays, etc.).

Not only that, you can see I have two nearly identical fold functions: one for folding a tuple of ints into an int, and another for folding it into a boolean. The burden is high enough that it makes a whole host of functional-programming-style tasks quite a chore instead of relatively smooth.

Without user-defined generics, it is also much less attractive to provide any shareable stdlib-like module with all of these collections, since they are much less general purpose.

A generic Tuple class could look something like this:

from typing import Generic, TypeVar

T = TypeVar("T", bound=Expr)
V = TypeVar("V", bound=Expr)

class Tuple(Expr, Generic[T], ruleset=array_api_ruleset):
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], T]) -> None: ...

    @classmethod
    def empty(cls) -> Tuple[T]:
        return Tuple(Int(0), cls._never_index)

    @classmethod
    def _never_index(cls, i: Int) -> T: ...

    @classmethod
    def single(cls, i: T) -> Tuple[T]:
        return Tuple(Int(1), lambda _: i)

    @classmethod
    def from_vec(cls, vec: Vec[T]) -> Tuple[T]: ...

    def __add__(self, other: Tuple[T]) -> Tuple[T]:
        return Tuple(
            self.length() + other.length(),
            lambda i: (i < self.length()).if_(self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> T: ...

    @method(preserve=True)
    def __len__(self) -> int:
        return int(self.length())

    @method(preserve=True)
    def __iter__(self) -> Iterator[T]:
        return iter(self[i] for i in range(len(self)))

    def fold(self, init: V, f: Callable[[V, T], V]) -> V: ...

    def contains(self, i: T, eq_fn: Callable[[T, T], Boolean]) -> Boolean:
        return self.fold(FALSE, lambda acc, j: acc | eq_fn(i, j))

@array_api_ruleset.register
def _tuple(
    i: Int,
    i2: Int,
    k: i64,
    f: Callable[[V, T], V],
    idx_fn: Callable[[Int], T],
    vs: Vec[T],
    v: V,
):
    return [
        rewrite(Tuple(i, idx_fn).length()).to(i),
        rewrite(Tuple(i, idx_fn)[i2]).to(idx_fn(i2)),
        # From vec
        rewrite(Tuple.from_vec(vs).length()).to(Int(vs.length())),
        rewrite(Tuple.from_vec(vs)[Int(k)]).to(vs[k]),
        # fold
        rewrite(Tuple(0, idx_fn).fold(v, f)).to(v),
        rewrite(Tuple(i, idx_fn).fold(v, f)).to(
            f(Tuple(i - 1, lambda i: idx_fn(i + 1)).fold(v, f), idx_fn(Int(0))),
            eq(i == Int(0)).to(FALSE),
        ),
    ]
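For readers less familiar with this encoding, here is a plain-Python model (not the egglog API) of what the `length`, `__getitem__`, `__add__`, and `fold` rewrites are meant to compute for a tuple stored as a length plus an index function. `PyTuple` is a hypothetical name used only for illustration, and ordinary Python ints stand in for `Int`.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

T = TypeVar("T")
V = TypeVar("V")


@dataclass(frozen=True)
class PyTuple(Generic[T]):
    """Hypothetical plain-Python model of Tuple(length, idx_fn); not egglog."""

    length: int
    idx_fn: Callable[[int], T]

    def __getitem__(self, i: int) -> T:
        # Mirrors: rewrite(Tuple(i, idx_fn)[i2]).to(idx_fn(i2))
        return self.idx_fn(i)

    def __add__(self, other: "PyTuple[T]") -> "PyTuple[T]":
        # Concatenation: indices below self.length come from self, the rest from other
        return PyTuple(
            self.length + other.length,
            lambda i: self[i] if i < self.length else other[i - self.length],
        )

    def fold(self, init: V, f: Callable[[V, T], V]) -> V:
        # Mirrors the two fold rewrites:
        #   Tuple(0, idx_fn).fold(v, f) -> v
        #   Tuple(n, idx_fn).fold(v, f) -> f(Tuple(n - 1, j -> idx_fn(j + 1)).fold(v, f), idx_fn(0))
        if self.length == 0:
            return init
        rest = PyTuple(self.length - 1, lambda j: self.idx_fn(j + 1))
        return f(rest.fold(init, f), self.idx_fn(0))
```

Following the second fold rewrite, `fold` combines elements from the last one back to the first: `PyTuple(3, lambda i: i).fold([], lambda acc, x: acc + [x])` gives `[2, 1, 0]`.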

Approach

I understand it would be a rather large lift to add them to egglog itself. I started a rough design in this PR: https://github.com/egraphs-good/egglog/pull/299. A suggestion there was to try doing all of this as desugaring in the Python package instead.

That approach would work something like creating a separate concrete sort for each instantiation of a generic type and duplicating every rewrite rule that uses it.
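A minimal sketch of the per-instantiation step, assuming types are represented as nested tuples (`("List", "Int")` for List[Int]): substitute the type arguments into a generic signature and mangle the resulting concrete types into flat sort names. The helpers `substitute`, `mangle`, and `instantiate` are hypothetical, not part of the egglog Python package.

```python
from typing import Union

# Hypothetical representation: a type is a name (a base sort like "Int" or a
# type variable like "T") or a tuple (head, arg1, ...) such as ("List", "T").
Type = Union[str, tuple]


def substitute(ty: Type, bindings: dict[str, Type]) -> Type:
    """Replace type variables in `ty` according to `bindings`."""
    if isinstance(ty, str):
        return bindings.get(ty, ty)
    head, *args = ty
    return (head, *(substitute(a, bindings) for a in args))


def mangle(ty: Type) -> str:
    """Flatten a concrete type into a sort name, e.g. List[Int] -> 'List_Int'."""
    if isinstance(ty, str):
        return ty
    head, *args = ty
    return "_".join([head, *(mangle(a) for a in args)])


def instantiate(name: str, inputs: list[Type], output: Type, bindings: dict[str, Type]) -> str:
    """Emit a monomorphic egglog function declaration for one set of type arguments."""
    ins = [mangle(substitute(t, bindings)) for t in inputs]
    out = mangle(substitute(output, bindings))
    suffix = "_".join(mangle(bindings[v]) for v in sorted(bindings))
    return f"(function {name}_{suffix} ({' '.join(ins)}) {out})"
```

For example, `instantiate("ListAppend", [("List", "T"), "T"], ("List", "T"), {"T": "Int"})` produces `(function ListAppend_Int (List_Int Int) List_Int)`. Joining names with `_` is not injective in general (a sort named `A_B` would collide with `A[B]`), so a real pass would want an unambiguous scheme.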

I am still a bit unclear on how the details would work though. I assume when we add rules and then run rulesets, we would have to make many instances of each generic type and every rewrite rule that uses it. But it seems hard to know when to stop generating new rules.

For example, if there is a rule with facts that depend on Generic1[A] and Generic1[B], and actions that produce Generic2[A, B], do we parametrize A and B with every single sort that is registered? How would we then limit the infinite recursion (e.g. Generic1[Generic1[int]], etc.)?

If anyone has a sound way of thinking about it, that would be much appreciated.

saulshanabrook commented 1 month ago

For example, if there is a rule with facts that depend on Generic1[A] and Generic1[B], and actions that produce Generic2[A, B], do we parametrize A and B with every single sort that is registered? How would we then limit the infinite recursion (e.g. Generic1[Generic1[int]], etc.)?

To make this more specific, I came up with a "reasonable" example of a rewrite that might cause this kind of recursion:

# mypy: disable-error-code="empty-body"
from __future__ import annotations

from collections.abc import Callable

from egglog import *

type IntLike = i64Like | Int

class Int(Expr):
    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: IntLike) -> Int: ...
    def __gt__(self, other: IntLike) -> Boolean: ...
    def __lt__(self, other: IntLike) -> Boolean: ...
    def max(self, other: IntLike) -> Int: ...

converter(i64, Int, Int)

class Boolean(Expr):
    def if_[T](self, t: T, f: T) -> T: ...

@function
def eq_fn[T](a: T, b: T) -> Boolean: ...

class Product[T, V](Expr):
    def __init__(self, left: T, right: V) -> None: ...
    @property
    def left(self) -> T: ...
    @property
    def right(self) -> V: ...
    def set_left(self, value: T) -> Product[T, V]: ...
    def set_right(self, value: V) -> Product[T, V]: ...

    def eq(self, other: Product[T, V], l_eq: Callable[[T, T], Boolean], r_eq: Callable[[V, V], Boolean]) -> Boolean: ...

class Option[T](Expr):
    @classmethod
    def none(cls) -> Option[T]: ...
    def __init__(self, v: T) -> None: ...
    def match[V](self, some: Callable[[T], V], none: V) -> V: ...
    def map[V](self, fn: Callable[[T], V]) -> Option[V]: ...

class List[T](Expr):
    def __init__(self) -> None: ...
    def __getitem__(self, index: Int) -> T: ...
    def find_index(self, fn: Callable[[T], Boolean]) -> Option[Int]: ...
    def fold[V](self, f: Callable[[T, V], V], v: V) -> V: ...
    def set(self, index: Int, value: T) -> List[T]: ...
    def append(self, next: T) -> List[T]: ...
    def most_common(self) -> Option[T]:
        """
        Returns the most common element in the list.
        If no items are in the list, returns None.
        If multiple items are tied for most common, returns the first one.
        """
        # 1. Built up list of pairs of elements and their counts
        counts: List[Product[T, Int]] = self.fold(
            lambda x, acc: acc.find_index(lambda p: eq_fn(p.left, x)).match(
                # If we already have a pair for this element, increment the count
                lambda i: acc.set(i, acc[i].set_right(acc[i].right + 1)),
                # Otherwise, add a new pair with a count of 1
                acc.append(Product(x, Int(1))),
            ),
            List[Product[T, Int]](),
        )
        # 2. Find the highest count
        highest_count = counts.fold(
            lambda p, acc: Option(
                acc.match(
                    lambda h: (p.right > h.right).if_(p, h),
                    p,
                )
            ),
            Option[Product[T, Int]].none(),
        )
        return highest_count.map(lambda p: p.left)

def _most_common_definition[T](l: List[T], r: Option[T]):
    res = (
        l.fold(
            lambda x, acc: acc.find_index(lambda p: eq_fn(p.left, x)).match(
                lambda i: acc.set(i, acc[i].set_right(acc[i].right + 1)),
                acc.append(Product(x, Int(1))),
            ),
            List[Product[T, Int]](),
        )
        .fold(
            lambda p, acc: Option(
                acc.match(
                    lambda h: (p.right > h.right).if_(p, h),
                    p,
                )
            ),
            Option[Product[T, Int]].none(),
        )
        .map(lambda p: p.left)
    )
    # the definition of most_common will be turned into this rewrite:
    yield rewrite(l.most_common()).to(res)
    # which will in turn be turned into this rule
    yield rule(eq(r).to(l.most_common())).then(union(r).with_(res))

The way I understand the monomorphization approach, we would make multiple versions of the most_common rule, one for every possible set of input types.

So if I started by adding List[Int]() to the e-graph and then ran most_common on it, we would definitely need a version of that rule with T replaced by Int.

Generally, when analyzing which new types applying the rules could create, we can think of the facts of a rule as its input types and the actions as its output types. New types can only be added to the e-graph by actions, never by facts.
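One step of this type-level analysis can be sketched as matching each fact's type template against the concrete types already present, then substituting the resulting bindings into the action types. This sketch assumes a nested-tuple type representation and treats only `T` and `V` as type variables; the templates at the bottom are the most_common fact and action types relevant to the recursion, reduced by hand. `match`, `substitute`, and `step` are hypothetical helpers.

```python
from itertools import product
from typing import Optional, Union

# Hypothetical representation: a type is a name ("Int", or a type variable such
# as "T") or a tuple (head, arg1, ...) such as ("List", "T") for List[T].
Type = Union[str, tuple]
TYPE_VARS = {"T", "V"}  # assumption: these names are type variables


def match(pattern: Type, ty: Type, bindings: dict) -> Optional[dict]:
    """One-way match of a type template against a concrete type; None on failure."""
    if isinstance(pattern, str) and pattern in TYPE_VARS:
        if pattern in bindings:
            return bindings if bindings[pattern] == ty else None
        return {**bindings, pattern: ty}
    if isinstance(pattern, str) or isinstance(ty, str):
        return bindings if pattern == ty else None
    if len(pattern) != len(ty) or pattern[0] != ty[0]:
        return None
    for p, t in zip(pattern[1:], ty[1:]):
        bindings = match(p, t, bindings)
        if bindings is None:
            return None
    return bindings


def substitute(ty: Type, bindings: dict) -> Type:
    """Replace bound type variables inside a template."""
    if isinstance(ty, str):
        return bindings.get(ty, ty)
    return (ty[0], *(substitute(a, bindings) for a in ty[1:]))


def step(facts: list, actions: list, present: set) -> set:
    """Types the rule's actions could add, given the concrete types present.

    Every fact template must match some present type with consistent bindings.
    Values are ignored entirely, which is what makes this analysis imprecise.
    """
    created = set()
    for combo in product(present, repeat=len(facts)):
        bindings: Optional[dict] = {}
        for pattern, ty in zip(facts, combo):
            bindings = match(pattern, ty, bindings)
            if bindings is None:
                break
        if bindings is not None:
            created.update(substitute(a, bindings) for a in actions)
    return created


MOST_COMMON_FACTS = [("List", "T"), ("Option", "T")]
MOST_COMMON_ACTIONS = [("List", ("Product", "T", "Int")), ("Option", ("Product", "T", "Int"))]
```

`step(MOST_COMMON_FACTS, MOST_COMMON_ACTIONS, {("List", "Int"), ("Option", "Int")})` yields List[Product[Int, Int]] and Option[Product[Int, Int]], and those two types match the facts again on the next step.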

So in this rule, the facts have types List[T] and Option[T]. If we added List[Int]().most_common() to the e-graph, then the types List[Int] and Option[Int] are now in the e-graph, so we would monomorphize this function with T replaced by Int.

But now, in the actions of this rule, we add a number of additional types, including List[Product[Int, Int]] and Option[Product[Int, Int]]. If we want to run our monomorphization until a fixpoint, we will recurse infinitely here: next we would want to make a most_common where T is Product[Int, Int]. Of course, we can see that we don't actually need this version, because most_common is not called recursively, e.g. from within the definition of fold. But if we only look at the type level and don't try to symbolically evaluate rules to understand which causes which at the value level (which seems infeasible), then we don't know this, and we would have to generate the new definition.

I am only saying all this to give a vaguely realistic example of how this infinite recursion could happen, to see if there is any advice on other ways to look at the problem that would allow accurate monomorphization of user-defined generics. One workaround would be to cap the recursion at some number of levels, say 3, so it would only generate up to three nested definitions. However, this would be fundamentally incomplete.
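The depth-cap workaround can be sketched as a fixpoint loop that discards any type nested deeper than a limit. Here the most_common rule is hand-translated into a type-level function (`most_common_rule`, hypothetical), and `saturate` reports whether the cap actually cut anything off, which is exactly the incompleteness described above.

```python
from typing import Union

# Hypothetical representation: a type is a name or a (head, args...) tuple.
Type = Union[str, tuple]


def depth(ty: Type) -> int:
    """Nesting depth: Int -> 0, List[Int] -> 1, List[Product[Int, Int]] -> 2."""
    if isinstance(ty, str):
        return 0
    return 1 + max(depth(a) for a in ty[1:])


def most_common_rule(present: set) -> set:
    """Type-level view of the most_common rule.

    Facts: List[T] and Option[T]. Actions (at least): List[Product[T, Int]]
    and Option[Product[T, Int]].
    """
    out = set()
    for ty in present:
        if not isinstance(ty, str) and ty[0] == "List" and ("Option", ty[1]) in present:
            t = ty[1]
            out.add(("List", ("Product", t, "Int")))
            out.add(("Option", ("Product", t, "Int")))
    return out


def saturate(seed: set, max_depth: int) -> tuple[set, bool]:
    """Apply the rule until nothing new appears, dropping types deeper than max_depth.

    Returns (types, truncated); truncated is True if the cap discarded anything.
    """
    present, truncated = set(seed), False
    while True:
        produced = most_common_rule(present)
        kept = {t for t in produced if depth(t) <= max_depth}
        truncated |= kept != produced
        if kept <= present:
            return present, truncated
        present |= kept
```

Starting from List[Int] and Option[Int] with `max_depth=3`, the loop stops with six types and `truncated=True`: the instantiation for T = Product[Product[Int, Int], Int] would need depth-4 types, so it is silently skipped.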

saulshanabrook commented 1 month ago

After chatting with @ezrosent, I rewrote this example in egglog to show how generic types could work there:

;; changes to egglog
;; 1. Allow unbound types in function definitions, these will be generic params
;; 2. Allow a non-zero number when defining a sort to indicate the arity of the sort
;; 3. Require sorts with non-zero arity to be parameterized, i.e. (List i64)
;; 4. Allow collection primitives to be parameterized inline, i.e. (UnstableFn (i64) i64) instead of requiring them be named
;; 5. Desugar anonymous functions to named functions with a rewrite, i.e. desugar
;;    (lambda (x) (replace x " " "")) to
;;    (function __tmp_name (String) String)
;;    (rewrite (__tmp_name x) (replace x " " ""))

(sort Boolean)

(function if (Boolean T T) T)

(sort Int)

(function IntInit (i64) Int)
(function IntAdd (Int Int) Int)
(function IntGt (Int Int) Boolean)
(function IntLt (Int Int) Boolean)
(function IntMax (Int Int) Int)

(function eq (T T) Boolean)

(sort Product 2)

(function ProductInit (T V) (Product T V))
(function ProductLeft ((Product T V)) T)
(function ProductRight ((Product T V)) V)
(function ProductSetLeft ((Product T V) T) (Product T V))
(function ProductSetRight ((Product T V) V) (Product T V))

(sort Option 1)

(function OptionNone () (Option T))
(function OptionSome (T) (Option T))

(function OptionMatch ((Option T) (UnstableFn (T) V) V) V)
(function OptionMap ((Option T) (UnstableFn (T) V)) (Option V))

(sort List 1)
(function ListInit () (List T))
(function ListGet ((List T) Int) T)
(function ListFindIndex ((List T) (UnstableFn (T) Boolean)) (Option Int))
(function ListFold ((List T) (UnstableFn (T V) V) V) V)
(function ListSet ((List T) Int T) (List T))
(function ListAppend ((List T) T) (List T))
(function ListMostCommon ((List T)) (Option T))

(rewrite
    (ListMostCommon l)
    (OptionMap
        ;; Find the highest count
        (ListFold
            ;; Built up list of pairs of elements and their counts
            (ListFold
                l
                (lambda (x acc)
                    (OptionMatch
                        (ListFindIndex
                            acc
                            (lambda (p) (eq (ProductLeft p) x))
                        )
                        ;; If we already have a pair for this element, increment the count
                        (lambda (i)
                            (ListSet acc i (ProductSetRight (ListGet acc i) (IntAdd (ProductRight (ListGet acc i)) (IntInit 1))))
                        )
                        ;; Otherwise, add a new pair with a count of 1
                        (ListAppend acc (ProductInit x (IntInit 1)))
                    )
                )
                (ListInit)
            )
            (lambda (p acc)
                (OptionSome
                    (OptionMatch
                        acc
                        (lambda (h)
                            (if (IntGt (ProductRight p) (ProductRight h))
                                p
                                h
                            )
                        )
                        p
                    )
                )
            )
            (OptionNone)
        )
        (lambda (p) (ProductLeft p))
    )
)
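Change 5 from the list at the top of this example could be a purely syntactic pass. A minimal sketch, assuming the lambda's parameter and return types are already known (for example from the `UnstableFn` type it is passed as) and that its body has no free variables besides its parameters. Lambdas like `(lambda (p) (eq (ProductLeft p) x))` above capture `x`, so captured variables would have to become extra parameters of the generated function. `desugar_lambda` and the `__lambda_N` naming scheme are hypothetical.

```python
import itertools

# Hypothetical fresh-name source; a real pass would need names that cannot
# collide with user-defined functions.
_fresh = itertools.count()


def desugar_lambda(params: list[str], param_types: list[str], return_type: str, body: str) -> tuple[str, list[str]]:
    """Turn (lambda (params...) body) into a named function plus a rewrite.

    Assumes `body` has no free variables besides `params`.
    Returns the fresh function name and the egglog declarations to emit.
    """
    name = f"__lambda_{next(_fresh)}"
    declaration = f"(function {name} ({' '.join(param_types)}) {return_type})"
    rewrite = f"(rewrite ({name} {' '.join(params)}) {body})"
    return name, [declaration, rewrite]
```

For the example in the comment, `desugar_lambda(["x"], ["String"], "String", '(replace x " " "")')` emits a `(function __lambda_N (String) String)` declaration and the matching `(rewrite (__lambda_N x) (replace x " " ""))`.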

In this case, if I add a (sort ListInts (List Int)), then make a list and call most common on it, it should generate a version of the rewrite I wrote, specialized to lists of ints. If we think about how this rewrite translates to a rule, on the facts of the rule we have types:

(List T), (Option T)

In the actions of the rule we have types (at least):

(List (Product T Int)), (Option (Product T Int))

So if we start with types (List Int) and (Option Int), then we should generate types (List (Product Int Int)) and (Option (Product Int Int))... Therefore, we would match on this rule again and recurse...

My assumption here is that we are only analyzing every rule at the type level, without looking at any of the actual contents of the rules.

EDIT:

Shorter Example

Here is a shorter example that isn't "realistic" but is much easier to read:

(sort A 1)

(function CreateA (T) (A T))
(function CreateA2 (T) (A T))

(rule
    ((CreateA x))
    ((CreateA2 (CreateA2 x)))
)

(CreateA 1)

;; If we look at this program, we start with types:
;; (A i64)
;; We see that the rule has (A T) in the facts and (A (A T)) in the body
;; so then we match on (A i64) and create (A (A i64)) and the rule for it:

(sort A_Int)
(function CreateA_Int (Int) (A_Int Int))
(function CreateA_Int2 (Int) (A_Int Int))

(rule
    ((CreateA_Int x))
    ((CreateA_Int2 (CreateA_Int2 x)))
)

;; However, now we recurse, and since we only analyze the rule at the type level, not the function level,
;; we don't know that we only need to recurse once, so we would make A_Int_Int.

(sort A_Int_Int)

(function CreateA_Int_Int (A_Int) (A_Int_Int A_Int))
(function CreateA_Int_Int2 (A_Int) (A_Int_Int A_Int))

(rule
    ((CreateA_Int_Int x))
    ((CreateA_Int_Int2 (CreateA_Int_Int2 x)))
)
ezrosent commented 1 month ago

I like the analogy to monomorphization upthread, but I don't think it makes sense to treat monomorphization for egglog as instantiating a function or rule for every type present in the program. In a language like Rust, you'd only create a function like CreateA Int if CreateA is actually called on an integer.

What's weird with egglog is that functions aren't explicitly called: they get called "automatically" based on the contents of the database. But I still think the rules for instantiation should be at the 'function-level', not the type-level, as Saul's pointing out here. We talked offline about having a sort of fixpoint computation of which types are needed. One way to think about doing that is to rewrite the rules so that all the values are erased but the types show up explicitly, e.g.:

(rule ((CreateA T)) ((CreateA2 (CreateA2 T))))

(rule ((CreateA_Int)) ((CreateA_Int2)))

(rule ((CreateA_Int_Int)) ((CreateA_Int2_Int2)))

I think for any seed values from the database, these rules will saturate, and we can read off the contents of the "type-level" CreateA and CreateA2 databases to see which types we need to instantiate for the functions with type params, and then go on to translate them.
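A minimal sketch of this function-level fixpoint for the CreateA example, assuming entries of the form (function name, type argument) and a hand-written version of the value-erased rule. `saturate_function_level` is a hypothetical name. Because the rule only fires on CreateA entries and only produces CreateA2 entries, the loop stops even though CreateA2 is needed at a nested type.

```python
def saturate_function_level(seed: set) -> set:
    """Function-level instantiation analysis for the CreateA example.

    Each entry (fn, t) means "fn is needed at type argument t". The value-erased rule
        (rule ((CreateA T)) ((CreateA2 (CreateA2 T))))
    fires only on CreateA entries and only produces CreateA2 entries.
    """
    needed = set(seed)
    while True:
        new = set()
        for fn, t in needed:
            if fn == "CreateA":
                new.add(("CreateA2", t))         # inner call: CreateA2 at T
                new.add(("CreateA2", ("A", t)))  # outer call: CreateA2 at (A T)
        if new <= needed:
            return needed
        needed |= new
```

Seeding it with `("CreateA", "i64")` saturates after one round, with CreateA needed at i64 and CreateA2 needed at both i64 and (A i64), so only those instantiations would have to be generated.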

One thing we'd need to do here is to take steps to guarantee saturation. A few ideas:

yihozhang commented 4 weeks ago

I think under this instantiation-as-you-go model, the current design around rule sets and schedules does not play well with constructors like OptionNone. Consider

(sort Option 1)

(function OptionSome (T) (Option T))
(function OptionNone () (Option T))

(rule () ((OptionNone)) :ruleset none-creation)

(OptionSome 1)
(run none-creation 1) ;; which types of T should we instantiate for OptionNone?

(function f () (Option String))
;; should this rule (where T is instantiated to String) fire? according to none-creation rule, this should fire for sure.
;; But in the proposed semantics, whether this would fire depends on
;; whether (Option String) is instantiated before the `none-creation` ruleset is run
(rule ((= e (OptionNone))) ((set (f) e)) :ruleset should-this-fire)
(run should-this-fire 1)
saulshanabrook commented 3 weeks ago

@yihozhang I would disallow things like (rule () ((OptionNone)) :ruleset none-creation). I think every type variable should be bound by the facts of a rule, not left unbound for its actions. This is similar to how the following rule would be disallowed, because the variable a is unconstrained and could match anything: (rule ((a)) ((OptionSome a))).
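This restriction can be sketched as a simple syntactic check over a rule's types, assuming a nested-tuple type representation and that `T` and `V` are the only type variables: collect the type variables appearing in the facts and reject any that appear only in the actions. `free_vars` and `check_rule` are hypothetical helpers.

```python
from typing import Union

# Hypothetical representation: a type is a name or a (head, args...) tuple.
Type = Union[str, tuple]
TYPE_VARS = {"T", "V"}  # assumption: these names are type variables


def free_vars(ty: Type) -> set:
    """Collect the type variables occurring in a type template."""
    if isinstance(ty, str):
        return {ty} if ty in TYPE_VARS else set()
    return set().union(*(free_vars(a) for a in ty[1:]))


def check_rule(fact_types: list, action_types: list) -> None:
    """Reject rules whose actions mention a type variable the facts do not bind."""
    bound = set().union(*(free_vars(t) for t in fact_types))
    used = set().union(*(free_vars(t) for t in action_types))
    unbound = used - bound
    if unbound:
        raise TypeError(f"type variables {sorted(unbound)} are not bound by any fact")
```

With this check, `check_rule([], [("Option", "T")])` (the none-creation rule) raises `TypeError`, while a rule whose facts mention `(List T)` may freely produce `(List (Product T Int))`.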

I think if we want to allow the semantics of that rule but with a particular type in mind, we have to allow explicit typing in the frontend.

In the Python bindings, this is already required: I wouldn't allow you to write Option.None() even if the type could be inferred, and would instead require an explicit annotation like Option[Int].None() or Option[T].None(). So it would be nice if I could use this information when translating, so as not to require more type inference than necessary.

So we would want to add a way in the syntax to explicitly parameterize the types. I could add this to the examples if it's helpful...


@ezrosent I can see the rough outline of what you mean by doing this at a per-function level, which would allow more precision than a type-level analysis, but as you say, we could still end up in a position where it's not entirely precise. Moreover, it seems like a lot of complexity to add... And it makes me more curious about exploring a non-monomorphization approach, where rules with type variables are preserved as-is into the query compiler, and the type variables are only instantiated when the query is matched?
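To make that alternative concrete, here is a toy sketch of match-time instantiation: the rule's pattern keeps its type variable, and each matched e-node's concrete type supplies the binding. `Node`, the list standing in for the database, and `match_generic` are all hypothetical; a real query compiler would do this inside its join rather than by scanning.

```python
from dataclasses import dataclass
from typing import Iterator, Optional, Union

# Hypothetical representation: a type is a name or a (head, args...) tuple.
Type = Union[str, tuple]
TYPE_VARS = {"T", "V"}  # assumption: these names are type variables


@dataclass(frozen=True)
class Node:
    """A toy e-node: operator, child e-class ids, and its concrete output type."""

    op: str
    children: tuple
    ty: Type


def match_type(pattern: Type, ty: Type, bindings: dict) -> Optional[dict]:
    """One-way match of a generic type template against a concrete type."""
    if isinstance(pattern, str) and pattern in TYPE_VARS:
        if pattern in bindings:
            return bindings if bindings[pattern] == ty else None
        return {**bindings, pattern: ty}
    if isinstance(pattern, str) or isinstance(ty, str):
        return bindings if pattern == ty else None
    if len(pattern) != len(ty) or pattern[0] != ty[0]:
        return None
    for p, t in zip(pattern[1:], ty[1:]):
        bindings = match_type(p, t, bindings)
        if bindings is None:
            return None
    return bindings


def match_generic(nodes: list, op: str, pattern_ty: Type) -> Iterator[tuple]:
    """Yield (node, type_bindings) for each node with operator `op` whose type fits.

    No per-type copy of the rule exists: type variables are bound only when a
    concrete node is matched.
    """
    for node in nodes:
        if node.op == op:
            bindings = match_type(pattern_ty, node.ty, {})
            if bindings is not None:
                yield node, bindings


DB = [
    Node("OptionNone", (), ("Option", "String")),
    Node("OptionNone", (), ("Option", "i64")),
    Node("OptionSome", (0,), ("Option", "i64")),
]
```

Matching the pattern (OptionNone) at type (Option T) against `DB` binds T to String for one node and to i64 for the other; the rule's actions would then be instantiated per match rather than per type ahead of time.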