Open shashi opened 3 years ago
So, this is essentially SSA form, right? What if we made a struct SSATerm (or CSETerm if you prefer) that behaves as if it were a Term but actually stores its contents in this manner.
This would certainly make interop with things like Mjolnir.jl and IRTools.jl easier.
> behaves as if it were a Term
The interface for this is:
operation(t)::Function
arguments(t)::Vector
I can imagine operation(t) being something like function block end. But arguments(t) will need to contain assignment operations, so I'm not quite sure how to represent assignments in the world view of terms. I don't think = as the head of the term captures the same meaning as changing an environment.
But it's perfectly possible to convert this thing to a Term when needed.
So I think we can get most of what we need by adding this as a top-level feature separate from terms, but with the opt-in conversion. Call it BasicBlock or something. And when we have the ability to do conditionals, we can add more such nodes and compose BasicBlocks to form more complex pieces of computation.
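As a rough illustration of the idea above (not the design that ended up in the package), here is a minimal sketch of a block of ordered assignments with an opt-in conversion to an ordinary expression tree. It uses plain Julia `Expr` and `Symbol` values as stand-ins for symbolic terms; the field names and the `to_expr` helper are hypothetical.

```julia
# Hypothetical sketch: a straight-line block of assignments plus a result.
# Plain `Expr`/`Symbol` values stand in for symbolic terms here.
struct BasicBlock
    assignments::Vector{Pair{Symbol,Any}}  # ordered `name => expression`
    result::Any                            # expression giving the block's value
end

# Opt-in conversion: wrap the result in one `let` per assignment, innermost
# last, so that later bindings can refer to earlier ones.
function to_expr(b::BasicBlock)
    ex = b.result
    for (name, rhs) in Iterators.reverse(b.assignments)
        ex = Expr(:let, Expr(:(=), name, rhs), Expr(:block, ex))
    end
    return ex
end

block = BasicBlock([:x => 0.3, :t1 => :(cos(x)), :t2 => :(cos(t1) + sin(t1))], :t2)
value = eval(to_expr(block))  # evaluates the nested `let` expression
```

Keeping the assignments in a separate ordered list sidesteps the question of what `=` would mean as the head of a term: the environment only appears once the block is converted.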
Hi @shashi, I had modified this to work in my context, but it's no longer working. Could something have changed in the last merge that broke it?
My setup is like this:
function cse(s::Symbolic)
    vars = atoms(s)
    dict = OrderedDict()
    r = @rule ~x => csestep(~x, vars, dict)
    final = RW.Postwalk(RW.PassThrough(r))(s)
    [[var=>ex for (ex, var) in pairs(dict)]...]
end
export csestep
csestep(s::Sym, vars, dict) = s
csestep(s, vars, dict) = s
function csestep(x::S, vars, dict) where {S <: Symbolic}
    # Avoid breaking local variables out of their scope
    isempty(setdiff(atoms(x), vars)) || return x
    if !haskey(dict, x)
        dict[x] = Sym{symtype(x)}(gensym())
    end
    return dict[x]
end
Here atoms returns the free variables (maybe I should call it that instead, but atoms is shorter). I need this because it has to know that a Sum is built from something like a lambda term (index -> value), and the index should never escape the sum.
This used to work, but now it stops when it gets to sin. Any idea what's going on?
I'm not sure, try Chain instead of PassThrough.
Thanks, I fixed it a while back:
julia> cse(cos(cos(x)) + sin(cos(x)))
4-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
Symbol("##473") => cos(x)
Symbol("##474") => cos(var"##473")
Symbol("##475") => sin(var"##473")
Symbol("##476") => var"##474" + var"##475"
julia> cse(cos(cos(x)) + cos(cos(x)))
3-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
Symbol("##477") => cos(x)
Symbol("##478") => cos(var"##477")
Symbol("##479") => 2var"##478"
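For readers without the SymbolicUtils setup, the same post-order idea can be sketched on plain Julia `Expr` call trees. This is an illustration under simplifying assumptions, not the code from this thread: `cse_expr` and the counter-based `fresh` (a stand-in for `gensym()`) are hypothetical names, and there is no handling of bound variables such as a Sum index.

```julia
# Sketch only: CSE over plain `Expr` call trees, not SymbolicUtils terms.
function cse_expr(ex)
    seen = Dict{Any,Symbol}()      # rewritten subexpression => temporary name
    order = Pair{Symbol,Any}[]     # assignments in the order they were created
    fresh() = Symbol("t", length(order) + 1)  # hypothetical stand-in for gensym()

    function walk(e)
        # Symbols and literals are leaves and pass through unchanged.
        e isa Expr && e.head === :call || return e
        # Rewrite the arguments first (post-order), keeping the function name.
        rewritten = Expr(:call, e.args[1], map(walk, e.args[2:end])...)
        # Reuse the existing name if this exact subexpression was seen before.
        get!(seen, rewritten) do
            name = fresh()
            push!(order, name => rewritten)
            name
        end
    end

    walk(ex)
    return order
end

steps = cse_expr(:(cos(cos(x)) + sin(cos(x))))
```

With this input the sketch yields four assignments, mirroring the first result above. Because it works on unsimplified `Expr` trees, `cos(cos(x)) + cos(cos(x))` would end in `t2 + t2` rather than the simplified `2var"##478"`.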
Since I also have an implementation of this in ReversePropagation.jl, maybe we should pool our efforts and add this to Symbolics.jl?
Nice! I need to have another look at that package, I didn't realize you have CSE set up in it.
I'd love to have something like this for general-purpose use, as long as there's a way to represent free variables. For me this comes up in symbolic representation of sums, since sharing the index variable outside the sum doesn't really make sense. Ideally SymbolicUtils can have symbolic summations built in; my current implementation works but feels a little hacky.
@cscherrer Can you give an example of what you mean by symbolic summations?
Say you have this model:
julia> m = @model x begin
           a ~ Normal()
           b ~ Normal()
           y ~ For(1:100000) do j Normal(μ = a + b*x[j]) end
           return y
       end;
and some fake data,
julia> x = randn(100000);
julia> y = rand(m(x=x));
Then the posterior is
Sum(-0.5(((getindex(y, var"##i1#739")) - a - (b*(getindex(x, getindex(UnitRange(1, 100000), var"##i1#739")))))^2), var"##i1#739", 1, 100000) - (0.5(a^2)) - (0.5(b^2))
We really don't want to expand this sum, since we'd have a ridiculous number of terms. But we can still apply some rules and then do some constant folding, so we end up with
-0.5(121986.42933250747 + (10139.425041642466a) + (100001(a^2)) + (100587.87406561377(b^2)) + (262.72135043120505a*b) - (94008.24079371039b))
which is then really fast to sample from.
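To make the folding step concrete, here is a small numeric sketch with made-up data (four points instead of 100000). It assumes the expression has exactly the shape shown above: a sum of squared residuals plus the two standard normal prior terms. Expanding the square leaves a polynomial in a and b whose coefficients are data sums computed once. The `(n + 1)` coefficient on `a^2` has the same form as the 100001 in the folded result above.

```julia
# Sketch with small made-up data; `x`, `y`, `a`, `b` mirror the names above.
x = [0.5, -1.0, 2.0, 0.25]
y = [1.0, 0.0, 3.5, 0.75]
n = length(x)

# One pass over the data gives every coefficient of the expanded quadratic.
Sy, Syy, Sx, Sxx, Sxy = sum(y), sum(abs2, y), sum(x), sum(abs2, x), sum(x .* y)

# Direct evaluation touches every data point on each call.
direct(a, b) = -0.5 * sum((y .- a .- b .* x) .^ 2) - 0.5a^2 - 0.5b^2

# Folded form: the same value from precomputed sums, independent of n per call.
folded(a, b) = -0.5 * (Syy - 2Sy * a + (n + 1) * a^2 + (Sxx + 1) * b^2 +
                       2Sx * a * b - 2Sxy * b)
```

Both functions agree up to floating-point error for any `a` and `b`, but `folded` costs a handful of multiplications per evaluation instead of a pass over the data.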
Thanks, that's interesting. I haven't thought about that situation at all. I've just been trying to get simple scalar expressions to work!
> add this to Symbolics.jl?
Naah! It should be here. See #200
Do you have an approach for tracking variable scope, and Lambda or Func expressions? Otherwise CSE has lots of corner cases.
The one in #200 is very conservative: it does not go inside a Func yet, and it outputs a self-contained Let.
done with @dpsanders :
Examples:
This issue is a good place to think about some API questions: