symforce-org / symforce

Fast symbolic computation, code generation, and nonlinear optimization for robotics
https://symforce.org
Apache License 2.0

Look into differences between SymPy and SymEngine CSE, and optimizations='basic' #38

Open aaron-skydio opened 2 years ago

hayk-skydio commented 2 years ago

From sympy/simplify/cse_main.py, note that optimizations='basic' includes two things:

```python
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
                       (factor_terms, None)]
```
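For anyone who wants to poke at this, here is a minimal sketch of how the two pairs can be inspected and how `cse` can be run with and without `optimizations='basic'`. The toy expression and the `rebuild` helper are assumptions for illustration only; this is not the pose Jacobian discussed in this issue.

```python
import sympy as sp
from sympy.simplify import cse_opts
from sympy.simplify.cse_main import basic_optimizations

x, y, z = sp.symbols("x y z")

# Each entry is a (preprocess, postprocess) pair; postprocess may be None.
sub_pair, factor_pair = basic_optimizations

# The second pair only preprocesses: factor_terms pulls common factors out of sums.
factored = sp.factor_terms(x*y + x*z)

# Toy expression (hypothetical stand-in, chosen to contain repeated differences).
expr = (x - y)*(z - y) + sp.sqrt((x - y)*(z - y))

plain = sp.cse(expr)                         # no pre/post-processing
basic = sp.cse(expr, optimizations="basic")  # sub_pre/sub_post + factor_terms


def rebuild(replacements, reduced):
    """Hypothetical helper: substitute CSE temporaries back, last-defined first."""
    out = reduced[0]
    for sym, sub in reversed(replacements):
        out = out.subs(sym, sub)
    return out


print(factored)
print(plain)
print(basic)
```

Both calls should describe the same function of `x`, `y`, `z`; only the set of extracted temporaries may differ.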

I applied each of those individually to the example (pose.inverse() * point).jacobian(pose) and jotted down notes:

Here is what the "sub" optimization does (copied from sympy/simplify/cse_opts.py):

```python
# Imports added so the excerpt runs on its own.
from sympy import Add, Basic, Mul, S, default_sort_key, preorder_traversal

def sub_pre(e):
    """ Replace y - x with -(x - y) if -1 can be extracted from y - x.
    """
    # replacing Add, A, from which -1 can be extracted with -1*-A
    adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()]
    reps = {}
    ignore = set()
    for a in adds:
        na = -a
        if na.is_Mul:  # e.g. MatExpr
            ignore.add(a)
            continue
        reps[a] = Mul._from_args([S.NegativeOne, na])

    e = e.xreplace(reps)

    # repeat again for persisting Adds but mark these with a leading 1, -1
    # e.g. y - x -> 1*-1*(x - y)
    if isinstance(e, Basic):
        negs = {}
        for a in sorted(e.atoms(Add), key=default_sort_key):
            if a in ignore:
                continue
            if a in reps:
                negs[a] = reps[a]
            elif a.could_extract_minus_sign():
                negs[a] = Mul._from_args([S.One, S.NegativeOne, -a])
        e = e.xreplace(negs)
    return e

def sub_post(e):
    """ Replace 1*-1*x with -x.
    """
    replacements = []
    for node in preorder_traversal(e):
        if isinstance(node, Mul) and \
            node.args[0] is S.One and node.args[1] is S.NegativeOne:
            replacements.append((node, -Mul._from_args(node.args[2:])))
    for node, replacement in replacements:
        e = e.xreplace({node: replacement})

    return e
```
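A small sketch of what this pair does in practice, assuming a toy expression with an undefined function `f` (hypothetical, picked so SymPy does not rewrite the arguments on its own). The point is that `x - y` and `y - x` start out as two different `Add` nodes, and after `sub_pre` they share one, which gives CSE a common subexpression to find. The `marked` value is hand-built with the private `Mul._from_args` to mimic the `1*-1*...` marker, just to show the pattern `sub_post` looks for.

```python
import sympy as sp
from sympy import Mul, S
from sympy.simplify.cse_opts import sub_pre, sub_post

x, y = sp.symbols("x y")
f = sp.Function("f")  # undefined function, so its arguments are left as given

# Two differences that are negatives of each other.
expr = f(x - y) + f(y - x)
pre = sub_pre(expr)

# Add nodes before: the outer sum plus both differences.
adds_before = expr.atoms(sp.Add)
# Add nodes after: the outer sum plus a single shared difference.
adds_after = pre.atoms(sp.Add)
print(adds_before)
print(adds_after)

# Hand-built marker of the form 1*-1*(x - y); sub_post collapses it to -(x - y).
marked = Mul._from_args([S.One, S.NegativeOne, x - y])
restored = sub_post(marked)
print(restored)

# Compare CSE without and with just the "sub" pair (outputs not asserted here).
print(sp.cse(expr))
print(sp.cse(expr, optimizations=[(sub_pre, sub_post)]))
```

Which of the two differences gets rewritten depends on `could_extract_minus_sign()`, so the sketch only relies on the two collapsing into one, not on which form survives.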