SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://symbolicml.org/DynamicExpressions.jl/dev
Apache License 2.0

`tree_map`? #23

Closed: MilesCranmer closed this issue 1 year ago

MilesCranmer commented 1 year ago

I realize a lot of functions could just be implemented as calls to a generic tree_map function. For example,

tree_map(t -> 1, tree; merge=max)

would calculate the depth of a tree. The “merge” function would be used to aggregate left/right child for binary nodes. For example,

tree_map(t -> 1, tree; merge=(+))

would count the total number of nodes. Meanwhile,

tree_map(tree; merge=(+)) do t
    Int(t.degree==2)
end

would count the number of binary operators. Then something like

tree_map(tree; merge=(l, r) -> [l..., r...]) do t
    if t.degree != 0 || !t.constant
        return []
    end
    return [t.val]
end

would return all constants in a tree (in depth-first traversal order).
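A minimal sketch of the proposed tree_map. It assumes a hypothetical ToyNode type (not DynamicExpressions' real Node) and a merge that receives the parent's mapped value followed by its children's results, i.e. the (parent, child_l, child_r) convention that comes up later in this thread. Under that convention the depth example needs p + maximum(c) rather than plain max, because max on its own never adds one per level.

```julia
# Hypothetical stand-in for DynamicExpressions' Node; field names mirror the thread.
struct ToyNode
    degree::Int                 # 0 = leaf, 1 = unary operator, 2 = binary operator
    constant::Bool              # leaves only: constant (true) or variable (false)
    val::Float64                # value of a constant leaf
    op::Int                     # operator index for degree 1 or 2
    l::Union{ToyNode,Nothing}   # left (or only) child
    r::Union{ToyNode,Nothing}   # right child
end

const_leaf(v) = ToyNode(0, true, Float64(v), 0, nothing, nothing)
var_leaf() = ToyNode(0, false, 0.0, 0, nothing, nothing)
unary(op, c) = ToyNode(1, false, 0.0, op, c, nothing)
binary(op, l, r) = ToyNode(2, false, 0.0, op, l, r)

# Apply `f` to every node. Leaves return f(leaf) directly; other nodes call
# merge(f(parent), child) for unary and merge(f(parent), left, right) for binary.
function tree_map(f, tree::ToyNode; merge)
    tree.degree == 0 && return f(tree)
    tree.degree == 1 && return merge(f(tree), tree_map(f, tree.l; merge))
    return merge(f(tree), tree_map(f, tree.l; merge), tree_map(f, tree.r; merge))
end

# Example tree: (x + 2.0) * sin(3.0); operator indices stand in for *, + and sin
tree = binary(1, binary(2, var_leaf(), const_leaf(2.0)), unary(1, const_leaf(3.0)))

depth = tree_map(_ -> 1, tree; merge=(p, c...) -> p + maximum(c))  # root counts as depth 1
n_nodes = tree_map(_ -> 1, tree; merge=+)
n_binary = tree_map(t -> Int(t.degree == 2), tree; merge=+)
constants = tree_map(tree; merge=vcat) do t
    t.degree == 0 && t.constant ? [t.val] : Float64[]
end
```

Using vcat as the merge plays the role of (l, r) -> [l..., r...] in the constants example, with the parent's (empty or one-element) vector placed first, which keeps the depth-first order.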


@Moelf would this have been helpful for writing that NYT puzzle solver? What do you think of the API?

@AlCap23 any comment?

Moelf commented 1 year ago

I'm not very familiar with common types of tree-based algorithms, but yeah, I think functions that facilitate tree-walking should be useful for making custom loss functions, right?

MilesCranmer commented 1 year ago

My god, it is beautiful.


The best part about this is that some of these functions actually got faster after this refactor.

Moelf commented 1 year ago

yeah, compiler > human when it comes to optimizing code (un)fortunately.

btw _ -> 1 can be expressed as Returns(1) in Julia now
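For reference, Returns (in Base since Julia 1.7) builds a callable that ignores whatever it is called with, so it can stand in for _ -> 1 wherever a function is expected. The variable names below are only illustrative.

```julia
# Returns(v) is a callable that ignores all positional and keyword arguments
# and always returns v.
always_one = Returns(1)

a = always_one(3.2)
b = always_one("any", :args; also_kwargs=true)
c = map(Returns(1), [1.5, 2.5, 3.5])   # same result as map(_ -> 1, ...)
```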

MilesCranmer commented 1 year ago

Yeah. And it's so much more readable now! Makes it easier to think of other functions to map over these objects.

> btw _ -> 1 can be expressed as Returns(1) in Julia now

I'm probably doing something wrong but it seems slower for some reason:

julia> f(_) = 1
f (generic function with 1 method)

julia> @btime f(3.2)
  1.272 ns (0 allocations: 0 bytes)
1

julia> g = Returns(1)
Returns{Int64}(1)

julia> @btime g(3.2)
  24.865 ns (0 allocations: 0 bytes)
1

Moelf commented 1 year ago

yeah, use const g = ... instead; right now g is a non-const global.
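The slowdown comes from g being a non-constant global: code that reads it cannot assume its type, so the call is dispatched at run time. One way to see this without BenchmarkTools is to compare inferred return types, as sketched below with illustrative wrapper names; with BenchmarkTools, interpolating the global as @btime $g(3.2) is the other usual fix.

```julia
# Non-const global: its type may change at any time, so a function that
# reads it has to look it up and dispatch dynamically.
g_nonconst = Returns(1)
call_nonconst() = g_nonconst(3.2)

# Const global: the binding's type is fixed, so the call can be inferred and inlined.
const g_const = Returns(1)
call_const() = g_const(3.2)

# Inferred return types of the two zero-argument wrappers
rt_nonconst = Base.return_types(call_nonconst, ())
rt_const = Base.return_types(call_const, ())
println(rt_nonconst, " vs ", rt_const)
```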

MilesCranmer commented 1 year ago

Cool. Thanks!

MilesCranmer commented 1 year ago

@Moelf what do you think about overloading Base.mapreduce rather than defining a custom tree_mapreduce? Since it has the same call syntax as a regular mapreduce, perhaps it is suitable. The one difference is that the merge function here takes (parent, child_l, child_r), whereas a normal mapreduce's reducer takes two values at a time, (a, b).
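To make the signature difference concrete, here is a sketch of what such an overload could look like on a hypothetical MyTree type (not DynamicExpressions' Node). Variadic reducers such as + and max happen to work under both conventions, but a reducer written for Base's two-argument contract fails on a node with two children, which is the kind of surprise an overload of Base.mapreduce could cause.

```julia
# Hypothetical tree type for illustration only.
struct MyTree
    val::Int
    children::Vector{MyTree}
end

# Hypothetical overload: `op` gets the parent's mapped value followed by one
# result per child, e.g. op(parent, child_l, child_r) for a binary node,
# whereas Base's mapreduce calls `op` with exactly two values at a time.
function Base.mapreduce(f, op, t::MyTree)
    isempty(t.children) && return f(t)
    return op(f(t), (mapreduce(f, op, c) for c in t.children)...)
end

t = MyTree(1, [MyTree(2, MyTree[]), MyTree(3, MyTree[])])

s = mapreduce(n -> n.val, +, t)      # variadic `+` accepts three arguments
m = mapreduce(n -> n.val, max, t)    # so does `max`
strict = try
    mapreduce(n -> n.val, (a, b) -> a + b, t)   # a strictly two-argument reducer
catch err
    err                                          # MethodError: called with three arguments
end
```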

MilesCranmer commented 1 year ago

Why don't I just overload all collection functions... Then you could just iterate through a tree!

Moelf commented 1 year ago

hm, I'm trying to imagine if walking down a tree is ~ iteration, and if that would surprise people

MilesCranmer commented 1 year ago

I added some other collection functions including iterate https://github.com/SymbolicML/DynamicExpressions.jl/blob/8d296193eb6c3e2e6f922c00674de29640f6ac2a/src/tree_map.jl#L107-L115

The behavior is to traverse a tree depth-first, left to right, and return the current node at each step. So you can do:

my_operator_count = 0
for node in tree
    # count uses of the unary operator with index 2
    if node.degree == 1 && node.op == 2
        my_operator_count += 1
    end
end

And it will work as you might expect. It’s a tiny bit slower than using the mapreduce because it allocates a stack of nodes, but I think it might be easier for users to write custom losses.
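A minimal sketch of how such an iterate method can work, assuming a hypothetical ToyNode type rather than the real Node (the actual implementation is the one linked above). The iteration state is the stack of nodes still to visit, which is where the extra allocation comes from.

```julia
# Hypothetical stand-in for DynamicExpressions' Node; field names mirror the thread.
struct ToyNode
    degree::Int                 # 0 = leaf, 1 = unary operator, 2 = binary operator
    constant::Bool              # leaves only: constant (true) or variable (false)
    val::Float64                # value of a constant leaf
    op::Int                     # operator index for degree 1 or 2
    l::Union{ToyNode,Nothing}   # left (or only) child
    r::Union{ToyNode,Nothing}   # right child
end

var_leaf() = ToyNode(0, false, 0.0, 0, nothing, nothing)
unary(op, c) = ToyNode(1, false, 0.0, op, c, nothing)
binary(op, l, r) = ToyNode(2, false, 0.0, op, l, r)

# The node count is not stored, so `collect` has to grow its result.
Base.IteratorSize(::Type{ToyNode}) = Base.SizeUnknown()
Base.eltype(::Type{ToyNode}) = ToyNode

# Depth-first, left-to-right traversal; the state is a stack of pending nodes.
Base.iterate(tree::ToyNode) = iterate(tree, ToyNode[tree])
function Base.iterate(::ToyNode, stack::Vector{ToyNode})
    isempty(stack) && return nothing
    node = pop!(stack)
    node.degree == 2 && push!(stack, node.r)   # push the right child first...
    node.degree >= 1 && push!(stack, node.l)   # ...so the left child is popped next
    return node, stack
end

# The counting loop from the comment above, wrapped in a function
function count_unary_op(tree::ToyNode, op::Int)
    my_operator_count = 0
    for node in tree
        if node.degree == 1 && node.op == op
            my_operator_count += 1
        end
    end
    return my_operator_count
end

tree = binary(1, unary(2, var_leaf()), unary(3, var_leaf()))
order = [(n.degree, n.op) for n in tree]
n_op2 = count_unary_op(tree, 2)
```

Because the tree is now an ordinary iterable, generic functions such as count(n -> n.degree == 0, tree) or collect(tree) work without further definitions.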

MilesCranmer commented 1 year ago

One thing that might make this more intuitive is to have a no-op type conversion. For example:

node_stack = DepthFirstTraversal(tree) |> collect

or filtering:

constant_nodes = filter(t -> t.degree == 0 && t.constant, DepthFirstTraversal(tree))

or looping:

for node in DepthFirstTraversal(tree)
    # next node will be child, if this node has degree > 0
end

DepthFirstTraversal would wrap the Node type and define how the tree is iterated over. But otherwise it wouldn't do anything, and (hopefully) wouldn't affect the performance.

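abstract type AbstractTraversal end

# Thin wrapper: stores the root and only selects the iteration order
struct DepthFirstTraversal{N<:Node} <: AbstractTraversal
    x::N
end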

In the future we could also add other traversal strategies.
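A sketch of the wrapper idea, again on a hypothetical ToyNode. DepthFirstTraversal follows the struct proposed here (made non-parametric for brevity), and BreadthFirstTraversal is an illustrative second strategy; neither is claimed to match DynamicExpressions' API. The two wrappers store only the root and differ in whether pending nodes sit in a stack or a queue.

```julia
# Hypothetical stand-in for DynamicExpressions' Node; field names mirror the thread.
struct ToyNode
    degree::Int                 # 0 = leaf, 1 = unary operator, 2 = binary operator
    constant::Bool              # leaves only: constant (true) or variable (false)
    val::Float64                # value of a constant leaf
    op::Int                     # operator index for degree 1 or 2
    l::Union{ToyNode,Nothing}   # left (or only) child
    r::Union{ToyNode,Nothing}   # right child
end

const_leaf(v) = ToyNode(0, true, Float64(v), 0, nothing, nothing)
var_leaf() = ToyNode(0, false, 0.0, 0, nothing, nothing)
unary(op, c) = ToyNode(1, false, 0.0, op, c, nothing)
binary(op, l, r) = ToyNode(2, false, 0.0, op, l, r)

abstract type AbstractTraversal end
struct DepthFirstTraversal <: AbstractTraversal
    x::ToyNode
end
struct BreadthFirstTraversal <: AbstractTraversal
    x::ToyNode
end

Base.IteratorSize(::Type{<:AbstractTraversal}) = Base.SizeUnknown()
Base.eltype(::Type{<:AbstractTraversal}) = ToyNode

# Depth-first, left to right: pending nodes in a stack (last in, first out)
Base.iterate(t::DepthFirstTraversal) = iterate(t, ToyNode[t.x])
function Base.iterate(::DepthFirstTraversal, stack::Vector{ToyNode})
    isempty(stack) && return nothing
    node = pop!(stack)
    node.degree == 2 && push!(stack, node.r)
    node.degree >= 1 && push!(stack, node.l)
    return node, stack
end

# Breadth-first, level by level: pending nodes in a queue (first in, first out)
Base.iterate(t::BreadthFirstTraversal) = iterate(t, ToyNode[t.x])
function Base.iterate(::BreadthFirstTraversal, queue::Vector{ToyNode})
    isempty(queue) && return nothing
    node = popfirst!(queue)
    node.degree >= 1 && push!(queue, node.l)
    node.degree == 2 && push!(queue, node.r)
    return node, queue
end

# Make `filter` return a Vector of matching nodes, built on the lazy Iterators.filter
Base.filter(f, t::AbstractTraversal) = collect(Iterators.filter(f, t))

# (x + 2.0) * sin(3.0)
tree = binary(1, binary(2, var_leaf(), const_leaf(2.0)), unary(1, const_leaf(3.0)))

node_stack = DepthFirstTraversal(tree) |> collect
constant_nodes = filter(t -> t.degree == 0 && t.constant, DepthFirstTraversal(tree))
dfs_degrees = [n.degree for n in DepthFirstTraversal(tree)]
bfs_degrees = [n.degree for n in BreadthFirstTraversal(tree)]
```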

What do you think?

MilesCranmer commented 1 year ago

@odow we chatted about tree structures at one point - I'd love to hear your take on this sort of interface!

odow commented 1 year ago

It looks nice. We don't really use algorithms over trees in JuMP/MOI. The first step in our AD engine is to convert everything to a single topologically sorted tape so everything requires a linear pass.
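For contrast, here is a toy version of the tape approach: flatten the tree once into a list in which every node appears after its arguments, then evaluate it with a single loop. The TNode type, tape layout, and function names are assumptions for illustration, not JuMP/MOI's actual data structures.

```julia
# Hypothetical expression node (not MOI's or DynamicExpressions' type).
mutable struct TNode
    op::Symbol                 # :constant, :variable, :add or :mul
    val::Float64               # used when op === :constant
    idx::Int                   # variable index when op === :variable
    children::Vector{TNode}
end
cst(v) = TNode(:constant, Float64(v), 0, TNode[])
variable(i) = TNode(:variable, 0.0, i, TNode[])
op_node(op, children...) = TNode(op, 0.0, 0, TNode[children...])

# Flatten into a tape where every node comes after its children
# (a topological order), using an explicit stack instead of recursion.
function to_tape(root::TNode)
    order = TNode[]
    stack = TNode[root]
    while !isempty(stack)
        n = pop!(stack)
        push!(order, n)               # parents are recorded before their descendants...
        append!(stack, n.children)
    end
    reverse!(order)                   # ...so reversing puts children first
    slot = IdDict{TNode,Int}()
    for (i, n) in enumerate(order)
        slot[n] = i
    end
    # Each entry: the node and the tape positions of its arguments
    return [(n, [slot[c] for c in n.children]) for n in order]
end

# Evaluate with one linear pass over the tape.
function eval_tape(tape, x::Vector{Float64})
    vals = Vector{Float64}(undef, length(tape))
    for (i, (n, args)) in enumerate(tape)
        if n.op === :constant
            vals[i] = n.val
        elseif n.op === :variable
            vals[i] = x[n.idx]
        elseif n.op === :add
            vals[i] = sum(vals[j] for j in args)
        else # :mul
            vals[i] = prod(vals[j] for j in args)
        end
    end
    return vals[end]                  # the root is the last entry
end

# (x1 + 2.0) * x2 evaluated at x = [3.0, 4.0]
expr = op_node(:mul, op_node(:add, variable(1), cst(2.0)), variable(2))
tape = to_tape(expr)
result = eval_tape(tape, [3.0, 4.0])
```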

The other issue I ran into was people constructing expressions nested so deeply that you can't use recursion.

A somewhat artificial example, but expressions like this cause trouble:

N = 1_000_000
x = [Variable() for _ in 1:N]
y = x[1]
for i in 2:N
    y = +(y, x[i])
end

(Overlook the fact that you could lift all the nodes into +(x...) etc. It's just an artificial example.)
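A sketch of why that shape is a problem, using a hypothetical PlusNode in place of JuMP's expression types: building the chain is a plain loop, but a recursive traversal needs one call frame per level, whereas an explicit stack keeps the pending nodes on the heap. The recursive version is only called on a small chain here.

```julia
# Hypothetical expression node: degree 0 = variable leaf, degree 2 = `+`
struct PlusNode
    degree::Int
    l::Union{PlusNode,Nothing}
    r::Union{PlusNode,Nothing}
end
leaf() = PlusNode(0, nothing, nothing)
plus(a, b) = PlusNode(2, a, b)

# Same shape as the loop above: y = +(y, x[i]) builds a left-deep chain of depth N
function build_chain(N)
    y = leaf()
    for _ in 2:N
        y = plus(y, leaf())
    end
    return y
end

# Recursive count: one call frame per level, so a depth of ~N can overflow the stack
count_recursive(n::PlusNode) =
    n.degree == 0 ? 1 : 1 + count_recursive(n.l) + count_recursive(n.r)

# Iterative count: pending nodes live in a heap-allocated vector instead
function count_iterative(root::PlusNode)
    stack = PlusNode[root]
    total = 0
    while !isempty(stack)
        n = pop!(stack)
        total += 1
        if n.degree == 2
            push!(stack, n.l)
            push!(stack, n.r)
        end
    end
    return total
end

small = count_recursive(build_chain(10))   # fine for shallow trees
y = build_chain(1_000_000);                # `;` avoids printing the deeply nested value
n_nodes = count_iterative(y)               # N leaves + (N - 1) `+` nodes
```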

MilesCranmer commented 1 year ago

I see, thanks. Indeed, that looks hard for recursion. We are lucky in this sense because we never have expressions with more than ~100 nodes or so, but in the future I definitely want to try some sort of stack-based evaluation.

By the way, unrelated, but eventually I would love to build some sort of interface between JuMP/MOI and SymbolicRegression.jl. Maybe so you could evolve symbolic models in a JuMP problem, or so you could use JuMP inside an objective to optimize a symbolic expression found by SymbolicRegression.jl. At the very least it could be useful to build a converter between them, like the one we have for Symbolics.jl.

MilesCranmer commented 1 year ago

Sometimes Julia code optimization is so weird. On Julia 1.9-rc3, this function:

function tree_mapreduce(f_leaf::F1, f_branch::F2, op::G, tree::N; preserve_sharing::Bool=false, result_type::Type{RT}=Nothing) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end

is 2x slower than this function (only change is commenting out the conditional return):

function tree_mapreduce(f_leaf::F1, f_branch::F2, op::G, tree::N; preserve_sharing::Bool=false, result_type::Type{RT}=Nothing) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    # preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end

even when preserve_sharing is set to false! This is true even if I annotate the return type, even with -O3, etc.

But I don't see this on 1.8.5.

Maybe the precompilation caching is breaking something about this inline macro?

MilesCranmer commented 1 year ago

Fixed in https://discourse.julialang.org/t/strange-performance-issue-on-1-9-0-rc3/98427/10. Was quite a subtle issue.

I have since merged the "tree as collections" in #27 to be included in v0.8.0 onwards. Thanks for the tips @Moelf @odow!