JuliaSymbolics / SymbolicUtils.jl

Symbolic expressions, rewriting and simplification
https://docs.sciml.ai/SymbolicUtils/stable/

Common subexpression elimination #121

Open shashi opened 3 years ago

shashi commented 3 years ago

Done with @dpsanders:

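using SymbolicUtils
using SymbolicUtils: Sym, Term
using SymbolicUtils.Rewriters
using DataStructures  # for OrderedDict, which preserves insertion order

newsym() = Sym{Number}(gensym("cse"))  # alternative: a symbolic variable instead of a bare Symbol

function cse(expr)
    dict = OrderedDict()  # subexpression => fresh name, in order of first appearance
    # Match any Term: reuse its name if it was seen before, otherwise assign a new gensym.
    r = @rule ~x::(x -> x isa Term) => haskey(dict, ~x) ? dict[~x] : dict[~x] = gensym()
    # Postwalk rewrites children before parents, so repeated subterms are named first.
    final = Postwalk(Chain([r]))(expr)
    # The assignments (name => subexpression), followed by the name of the whole expression.
    [[var=>ex for (ex, var) in pairs(dict)]..., final]
end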
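The same post-order idea can be sketched without SymbolicUtils, on plain Julia Expr trees and using only Base. This is a minimal sketch under two assumptions: every :call node counts as a candidate subexpression, and structurally equal Exprs count as the same subexpression. The names cse_expr and cse_walk are hypothetical, not part of SymbolicUtils.

```julia
# Hypothetical Base-only sketch of CSE on plain `Expr` trees (not the SymbolicUtils API).

# Leaves (symbols, numbers) are kept as they are.
cse_walk(x, names, assigns) = x

function cse_walk(x::Expr, names, assigns)
    # Post-order: rewrite the children first, so every repeated subexpression
    # below this node has already been replaced by its variable.
    y = Expr(x.head, (cse_walk(a, names, assigns) for a in x.args)...)
    y.head === :call || return y
    # Reuse the variable of a structurally equal call, or hoist a new one.
    return get!(names, y) do
        v = gensym("cse")
        push!(assigns, v => y)
        v
    end
end

# Returns the assignments (in dependency order) and the final expression.
function cse_expr(ex)
    names = Dict{Any,Symbol}()      # subexpression => variable holding it
    assigns = Pair{Symbol,Any}[]    # variable => subexpression, in order
    final = cse_walk(ex, names, assigns)
    return assigns, final
end
```

On cos(cos(x)) + sin(cos(x)) this produces four assignments (cos(x) is hoisted once and reused) plus the final variable, the same shape as the REPL output in this issue.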

Examples:

@syms x y
julia> cse(cos(cos(x)) + sin(cos(x)))
5-element Array{Any,1}:
 Symbol("##260") => cos(x)
 Symbol("##261") => cos(##260)
 Symbol("##262") => sin(##260)
 Symbol("##263") => ##261 + ##262
                  Symbol("##263")

julia> SymbolicUtils.show_simplified[] = false
false

julia> cse(cos(cos(x)) + sin(cos(x)))
5-element Array{Any,1}:
 Symbol("##264") => cos(x)
 Symbol("##265") => cos(##264)
 Symbol("##266") => sin(##264)
 Symbol("##267") => ##265 + ##266
                  Symbol("##267")

julia> cse(cos(cos(x)) + cos(cos(x)))
4-element Array{Any,1}:
 Symbol("##268") => cos(x)
 Symbol("##269") => cos(##268)
 Symbol("##270") => ##269 + ##269
                  Symbol("##270")

This issue is a good place to think about some API questions:

  1. What should CSE return?
  2. Should we be able to place that thing in a different Term?

MasonProtter commented 3 years ago

So, this is essentially SSA form, right? What if we made a struct SSATerm (or CSETerm if you prefer) that behaves as if it were a Term but actually stores its contents in this form?

This would certainly make interop with things like Mjolnir.jl and IRTools.jl easier.

shashi commented 3 years ago

behaves as if it were a Term

The interface for this is:

operation(t)::Function
arguments(t)::Vector

I can imagine operation(t) being a placeholder function such as block, declared with function block end.

But arguments(t) would need to contain assignment operations, and I'm not quite sure how to represent assignments in the worldview of terms. I don't think a term with = as its head captures the same meaning as changing an environment.
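To make the mismatch concrete, here is a minimal sketch of the operation/arguments view applied to plain Julia call expressions. The operation and arguments definitions below are hypothetical local stand-ins, not the SymbolicUtils functions, and here the operation is a Symbol such as :cos rather than a Function.

```julia
# Hypothetical Term-like view of a plain call expression `f(args...)`
# (local definitions, not SymbolicUtils' operation/arguments).
operation(ex::Expr) = ex.args[1]        # :cos for :(cos(x))
arguments(ex::Expr) = ex.args[2:end]    # [:x]  for :(cos(x))

# Generic code written only against the interface: count the call nodes.
# Assumes every Expr it meets is a call.
count_calls(x) = 0
count_calls(ex::Expr) = 1 + sum(count_calls, arguments(ex); init = 0)

# An assignment is not a call: Julia stores `a = f(x)` as Expr(:(=), :a, :(f(x))),
# so the naive view reports the assignment *target* :a as its "operation".
assignment = :(a = f(x))
```

This is the difficulty raised here: an assignment updates an environment rather than applying a function to arguments, so it does not fit the operation/arguments shape without extra conventions.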

But, it's perfectly possible to convert this thing to a Term when needed.

So I think we can get most of what we need by adding this as a top-level feature separate from terms, with an opt-in conversion. Call it BasicBlock or something. And when we have the ability to do conditionals, we can add more such nodes and compose BasicBlocks to form more complex pieces of computation.
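A minimal sketch of that idea, assuming plain Julia Expr values for the right-hand sides. BasicBlock and to_expr here are hypothetical names, not existing SymbolicUtils types. The block stores single assignments in order plus the variable holding the result, and the opt-in conversion inlines every variable to recover one nested expression.

```julia
# Hypothetical sketch of a "BasicBlock" (not a SymbolicUtils type): single
# assignments in SSA style plus the variable that holds the result.
struct BasicBlock
    assigns::Vector{Pair{Symbol,Any}}   # var => expression, in order
    result::Any                         # usually the last var
end

# Opt-in conversion back to one nested expression ("a Term"): substitute each
# variable with its definition, which has itself already been inlined.
function to_expr(bb::BasicBlock)
    env = Dict{Symbol,Any}()
    subst(x) = x isa Symbol ? get(env, x, x) :
               x isa Expr   ? Expr(x.head, map(subst, x.args)...) : x
    for (v, ex) in bb.assigns
        env[v] = subst(ex)
    end
    return subst(bb.result)
end
```

Inlining undoes the sharing, so in the worst case the recovered expression is exponentially larger than the block; that is one reason to keep the block as its own node and convert only on request.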

cscherrer commented 3 years ago

Hi @shashi, I had modified this to work in my context, but it's no longer working. Could something have changed in the last merge that broke it?

My setup is like this:

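using SymbolicUtils
using SymbolicUtils: Symbolic, Sym, symtype
using DataStructures: OrderedDict
const RW = SymbolicUtils.Rewriters  # assumed alias for the RW prefix used below

# atoms(s) is my own helper (not part of SymbolicUtils); it returns the free variables of s.
function cse(s::Symbolic)
    vars = atoms(s)
    dict = OrderedDict()
    r = @rule ~x => csestep(~x, vars, dict) 
    final = RW.Postwalk(RW.PassThrough(r))(s)
    [[var=>ex for (ex, var) in pairs(dict)]...]
end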

export csestep

# Leaves: symbolic variables and non-symbolic values are returned unchanged.
csestep(s::Sym, vars, dict) = s

csestep(s, vars, dict) = s

function csestep(x::S, vars, dict) where {S <: Symbolic}
    # Avoid breaking local variables out of their scope
    isempty(setdiff(atoms(x), vars)) || return x

    if !haskey(dict, x) 
        dict[x] = Sym{symtype(x)}(gensym())
    end

    return dict[x]
end

Here atoms returns the free variables (maybe I should call it that instead, but atoms is shorter). I need this because a Sum is built from something like a lambda term (index -> value), and the index must never escape the sum.
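The scope check can be sketched without SymbolicUtils on plain Expr trees. This is a minimal sketch that assumes a call such as Sum(body, j) whose index j is bound. The helper atoms_of is a hypothetical, binder-unaware stand-in for atoms, and scoped_cse is likewise hypothetical. Because atoms_of does not know that Sum binds j, this sketch is more conservative than a real free-variable analysis.

```julia
# Hypothetical, binder-unaware stand-in for `atoms`: every Symbol in argument
# position (function names in call position are skipped).
atoms_of(x) = x isa Symbol ? Set([x]) : Set{Symbol}()
function atoms_of(ex::Expr)
    args = ex.head === :call ? ex.args[2:end] : ex.args
    s = Set{Symbol}()
    for a in args
        union!(s, atoms_of(a))
    end
    return s
end

# Hypothetical scoped CSE: hoist a call only when all of its atoms are in scope.
# Anything that mentions a bound index (such as `j` in `Sum(body, j)`) stays inline.
function scoped_cse(ex, vars::Set{Symbol})
    scope = copy(vars)              # grows as new variables are introduced
    names = Dict{Any,Symbol}()      # subexpression => hoisted variable
    assigns = Pair{Symbol,Any}[]    # hoisted variable => subexpression
    function step(x)
        x isa Expr || return x
        y = Expr(x.head, map(step, x.args)...)   # children first (post-order)
        (y.head === :call && issubset(atoms_of(y), scope)) || return y
        return get!(names, y) do
            v = gensym("cse")
            push!(scope, v)         # the new variable is itself in scope
            push!(assigns, v => y)
            v
        end
    end
    return assigns, step(ex)
end
```

On Sum(x[j] * cos(a), j) + cos(a) with a and x in scope, only cos(a) is hoisted, because it does not mention j. The product x[j] * cos(a), the Sum itself, and the top-level + stay inline. An analysis that knows Sum binds j could also hoist the whole Sum.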

This used to work, but now it stops when it gets to sin. Any idea what's going on?

shashi commented 3 years ago

I'm not sure; try Chain instead of PassThrough.

cscherrer commented 3 years ago

Thanks, I fixed it a while back:

julia> cse(cos(cos(x)) + sin(cos(x)))
4-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
 Symbol("##473") => cos(x)
 Symbol("##474") => cos(var"##473")
 Symbol("##475") => sin(var"##473")
 Symbol("##476") => var"##474" + var"##475"

julia> cse(cos(cos(x)) + cos(cos(x)))
3-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
 Symbol("##477") => cos(x)
 Symbol("##478") => cos(var"##477")
 Symbol("##479") => 2var"##478"

dpsanders commented 3 years ago

Since I also have an implementation of this in ReversePropagation.jl, maybe we should pool our efforts and add this to Symbolics.jl?

cscherrer commented 3 years ago

Nice! I need to have another look at that package; I didn't realize you had CSE set up in it.

I'd love to have something like this for general-purpose use, as long as there's a way to represent free variables. For me this comes up in the symbolic representation of sums, since sharing the index variable outside the sum doesn't really make sense. Ideally SymbolicUtils would have symbolic summations built in; my current implementation works but feels a little hacky.

dpsanders commented 3 years ago

@cscherrer Can you give an example of what you mean by symbolic summations?

cscherrer commented 3 years ago

Say you have this model:

julia> m = @model x begin
           a ~ Normal()
           b ~ Normal()
           y ~ For(1:100000) do j
               Normal(μ = a + b*x[j])
           end
           return y
       end;

and some fake data,

julia> x = randn(100000);

julia> y = rand(m(x=x));

Then the log-density of the posterior is (up to an additive constant)

Sum(-0.5(((getindex(y, var"##i1#739")) - a - (b*(getindex(x, getindex(UnitRange(1, 100000), var"##i1#739")))))^2), var"##i1#739", 1, 100000) - (0.5(a^2)) - (0.5(b^2))

We really don't want to expand this sum, since we'd have a ridiculous number of terms. But we can still apply some rules and then do some constant folding, so we end up with

-0.5(121986.42933250747 + (10139.425041642466a) + (100001(a^2)) + (100587.87406561377(b^2)) + (262.72135043120505a*b) - (94008.24079371039b))

which is then really fast to sample from.
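The folding works because the squared residual expands into a quadratic in a and b whose coefficients depend only on the data. Expanding gives -0.5(Σy² - 2Σy·a + (n+1)a² + (Σx²+1)b² + 2Σx·ab - 2Σxy·b), which has the same shape as the folded expression, including the 100001 = n + 1 coefficient on a². Below is a minimal sketch, assuming the model above with n data points and the -0.5a² - 0.5b² prior terms. QuadLogDensity, fold_sums and direct_logdensity are hypothetical names.

```julia
# Hypothetical sketch: fold the n-term sum into six numbers computed once,
# after which evaluating the log-density costs O(1) instead of O(n).
struct QuadLogDensity
    c0::Float64   # Σ y²
    ca::Float64   # -2 Σ y
    caa::Float64  # n + 1 (the +1 comes from the a² prior term)
    cbb::Float64  # Σ x² + 1 (the +1 comes from the b² prior term)
    cab::Float64  # 2 Σ x
    cb::Float64   # -2 Σ x y
end

function fold_sums(x::AbstractVector, y::AbstractVector)
    n = length(x)
    return QuadLogDensity(sum(abs2, y), -2 * sum(y), n + 1,
                          sum(abs2, x) + 1, 2 * sum(x), -2 * sum(x .* y))
end

# -0.5 (c0 + ca a + caa a² + cbb b² + cab a b + cb b)
(q::QuadLogDensity)(a, b) =
    -0.5 * (q.c0 + q.ca * a + q.caa * a^2 + q.cbb * b^2 + q.cab * a * b + q.cb * b)

# Unexpanded form for comparison: O(n) work per evaluation.
direct_logdensity(x, y, a, b) =
    -0.5 * sum((y .- a .- b .* x) .^ 2) - 0.5 * a^2 - 0.5 * b^2
```

The two forms agree for any a and b, so a sampler can evaluate the folded version millions of times without touching the data again.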

dpsanders commented 3 years ago

Thanks, that's interesting. I haven't thought about that situation at all. I've just been trying to get simple scalar expressions to work!

shashi commented 3 years ago

add this to Symbolics.jl?

Naah! It should be here. See #200

cscherrer commented 3 years ago

Do you have an approach for tracking variable scope, and for Lambda or Func expressions? Otherwise CSE has lots of corner cases.

shashi commented 3 years ago

The one in #200 is very conservative: it does not go inside a Func yet, and it outputs a self-contained Let.
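To illustrate what a self-contained Let buys, here is a minimal sketch (not the #200 implementation). It packages CSE-style output, a list of var => expr pairs plus a final expression, as one let block over plain Julia Exprs and compiles it into a function. The name to_let and the variables t1 through t4 are hypothetical.

```julia
# Hypothetical sketch (not the #200 code): wrap CSE output in a single `let` block.
# Julia evaluates `let` bindings in order, so each binding may use earlier ones.
function to_let(assigns, final)
    bindings = [Expr(:(=), v, ex) for (v, ex) in assigns]
    return Expr(:let, Expr(:block, bindings...), Expr(:block, final))
end

assigns = [:t1 => :(cos(x)), :t2 => :(cos(t1)),
           :t3 => :(sin(t1)), :t4 => :(t2 + t3)]
body = to_let(assigns, :t4)       # let t1 = cos(x), t2 = cos(t1), ...; t4; end

# Only `x` is free in the block, so it becomes a function of `x` alone.
f = eval(:(x -> $body))
```

Because the block carries all of its intermediate bindings, it can be spliced into generated code or compiled once with eval, as done here, without leaking the temporaries into the surrounding scope.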