JuliaSymbolics / SymbolicUtils.jl

Symbolic expressions, rewriting and simplification
https://docs.sciml.ai/SymbolicUtils/stable/
Other
524 stars 99 forks source link

Stack overflow in custom interface #527

Open MilesCranmer opened 1 year ago

MilesCranmer commented 1 year ago

Hey all,

I am trying to build a direct interface in DynamicExpressions.jl to speed up simplification: https://github.com/SymbolicML/DynamicExpressions.jl/pull/42.

I am seeing a stack overflow at the moment, even for null rule sets. I am wondering if this may be because my type is constrained to binary trees ($\text{arity} \in \{0, 1, 2\}$)? Thus perhaps there is some initial expansion going on that fails.

Here's an example, using this commit: https://github.com/SymbolicML/DynamicExpressions.jl/pull/42/commits/c12d5a7672a7a7898cca106fee1a89a2accfbe80

using SymbolicUtils, DynamicExpressions

# Operator set used by the expression: four binary and two unary operators.
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])

# Node built with value type Float64 and `feature=1`
# (presumably a leaf referring to input feature 1 -- confirm against DynamicExpressions).
x1 = Node(Float64; feature=1)

# Expression that uses the same node for both operands of `+`.
expression = x1 + x1

# Simplify with an empty rule set; this is the call reported to raise the
# StackOverflowError shown in the stack trace.
simplify(SelfContainedNode(expression, operators), RuleSet([]))

which triggers the following error:

ERROR: StackOverflowError:
Stacktrace:
     [1] (::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:191
     [2] (::SymbolicUtils.Rewriters.PassThrough{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:188
     [3] iterate
       @ ./generator.jl:47 [inlined]
     [4] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, SymbolicUtils.Rewriters.PassThrough{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
       @ Base ./array.jl:802
     [5] collect_similar
       @ ./array.jl:711 [inlined]
     [6] map
       @ ./abstractarray.jl:3261 [inlined]
     [7] (::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:198
--- the last 6 lines are repeated 12138 more times ---
 [72836] macro expansion
       @ ~/.julia/packages/SymbolicUtils/H684H/src/utils.jl:11 [inlined]
 [72837] (::SymbolicUtils.Rewriters.Fixpoint{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:122
 [72838] PassThrough
       @ ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:188 [inlined]
 [72839] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}; expand::Bool, polynorm::Nothing, threaded::Bool, simplify_fractions::Bool, thread_subtree_cutoff::Int64, rewriter::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})
       @ SymbolicUtils ~/.julia/packages/SymbolicUtils/H684H/src/simplify.jl:41
 [72840] simplify
       @ ~/.julia/packages/SymbolicUtils/H684H/src/simplify.jl:16 [inlined]
 [72841] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}, ctx::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
       @ SymbolicUtils ./deprecated.jl:105
 [72842] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}, ctx::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})
       @ SymbolicUtils ./deprecated.jl:103
 [72843] #simplify#22
       @ ~/Documents/DynamicExpressions.jl/ext/DynamicExpressionsSymbolicUtilsExt.jl:352 [inlined]

My interface is as follows (full code here):

"""
    arity(x::SelfContainedNode)

Return the `degree` field of the tree wrapped by `x`.
"""
function arity(x::SelfContainedNode)
    return x.tree.degree
end
"""
    istree(x::SelfContainedNode)

Return whether `arity(x)` is greater than zero.
"""
function istree(x::SelfContainedNode)
    return 0 < arity(x)
end
"""
    symtype(x::SelfContainedNode{T})

Return `T`, the first type parameter of the `SelfContainedNode` type of `x`.
"""
function symtype(::SelfContainedNode{T}) where {T}
    return T
end
"""
    operation(x::SelfContainedNode)

Look up the operator at the root of `x`: index `x.tree.op` into `x.operators.unaops`
when `arity(x)` is 1, or into `x.operators.binops` when it is 2. Any other arity raises
an error.
"""
function operation(x::SelfContainedNode)
    n = arity(x)
    n == 1 && return x.operators.unaops[x.tree.op]
    n == 2 && return x.operators.binops[x.tree.op]
    # Neither unary nor binary: there is no operator to report.
    error("Unexpected arity $(arity(x)).")
end
"""
    unsorted_arguments(x::SelfContainedNode{T})

Return the children of the root of `x` as a `Vector{Any}`, left child first.

A child for which `isconstant` is true is returned as its raw value (`child.val::T`);
any other child is wrapped in the same `SelfContainedNode` type as `x`, sharing
`x.operators`. A node of arity 0 yields an empty vector.

# Throws
An error for an arity other than 0, 1 or 2, mirroring `operation`, instead of silently
returning `nothing`.
"""
function unsorted_arguments(x::S) where {T,S<:SelfContainedNode{T}}
    # Unwrap constants to plain values; keep every other child as a wrapped subtree.
    wrap(child) = isconstant(child) ? child.val::T : S(child, x.operators)
    n = arity(x)
    if n == 0
        return Any[]
    elseif n == 1
        # Only the left child is read for a unary node.
        return Any[wrap(x.tree.l)]
    elseif n == 2
        return Any[wrap(x.tree.l), wrap(x.tree.r)]
    else
        error("Unexpected arity $(n).")
    end
end
"""
    arguments(x::SelfContainedNode)

Return the children of `x`; this simply forwards to `unsorted_arguments(x)`.
"""
arguments(x::S) where {T,S<:SelfContainedNode{T}} = unsorted_arguments(x)
"""
    similarterm(t::SelfContainedNode{T}, f, args, symtype=nothing; kws...)

Build a new node of the same `SelfContainedNode` type as `t`, sharing `t.operators`,
that applies `f` to `args`.

- One argument: `f` is looked up in `t.operators.unaops` and a unary `Node` is created.
- Two arguments: `f` is looked up in `t.operators.binops`. If `isconstant` holds for
  every argument, `f(args...)` is evaluated and passed to `to_node`; otherwise a binary
  `Node` is created from the converted arguments.
- More than two arguments: the first two are combined into one node, which replaces
  them at the front of the list, and the shorter list is processed recursively.

`symtype` and `kws` are only forwarded to the recursive calls.

# Throws
An error when `args` is empty.
"""
function similarterm(
    t::S, f::F, args::AbstractArray, symtype=nothing; kws...
)::S where {T,S<:SelfContainedNode{T},F<:Function}
    if length(args) > 2
        # Combine the first two arguments, then recurse with the result in front.
        # Each recursive call receives one argument fewer, so the recursion ends once
        # two arguments remain.
        l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
        return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
    end
    if length(args) == 1
        op_index = mustfindfirst(f, t.operators.unaops)
        # `begin` rather than `1`, matching the slicing above, so that `args` whose
        # indices do not start at 1 are handled as well.
        new_node = Node(op_index, to_node(T, op_index, args[begin]))
        return S(new_node, t.operators)
    elseif length(args) == 2
        op_index = mustfindfirst(f, t.operators.binops)
        new_node = if all(isconstant, args)
            # Every argument is constant: evaluate `f` now and convert the result.
            to_node(T, op_index, f(args...))
        else
            Node(op_index, [to_node(T, op_index, arg) for arg in args]...)
        end
        return S(new_node, t.operators)
    else
        error("Unexpected length $(length(args)).")
    end
end

Basically, SelfContainedNode stores a Node (a binary tree in which 1-ary nodes are also allowed - see type description here) and an OperatorEnum. For x::Node, x.l is the left child and x.r is the right child; x.op is the index of the node's operator in the OperatorEnum.

My guess is that the first part of similarterm is triggering infinite recursion:

    # More than two arguments: combine the first two into a single node, then recurse
    # on the list with that node in front. The list gets one element shorter per call.
    if length(args) > 2
        l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
        return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
    end

This is required because a Node can only store two children at a time, so we recursively generate the tree here. Perhaps this breaks some assumption in SymbolicUtils.jl?

This might just be incompatible with the package, so feel free to close if so.