SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://ai.damtp.cam.ac.uk/dynamicexpressions
Apache License 2.0

Suppress Warning from ForwardDiff Duals #74

Open avik-pal opened 6 months ago

avik-pal commented 6 months ago

Turns out ForwardDiff.jl already works with DynamicExpressions, but we would want to turn off the type mismatch warnings.

using ForwardDiff, DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos]);
x1 = Node(; feature=1)  # default node value type is Float32
x2 = Node(; feature=2)
expr = x1 * cos(x2 - 3.2)

X = rand(2, 5)  # Float64 data: 2 features × 5 samples

# Gradient of a scalar loss with respect to the Float64 data matrix X:
ForwardDiff.gradient(X) do X
    return sum(abs2, first(eval_tree_array(expr, X, operators)))
end
┌ Warning: Warning: eval_tree_array received mixed types: tree=Float32 and data=ForwardDiff.Dual{ForwardDiff.Tag{var"#13#14", Float64}, Float64, 10}.
└ @ DynamicExpressions.EvaluateModule /mnt/research/lux/DynamicExpressions.jl/src/Evaluate.jl:95
2×5 Matrix{Float64}:
  0.182636   0.882172   0.906966    0.161635    1.86929
 -0.020271  -0.363872  -0.0764073  -0.0126587  -0.331082
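
In the meantime, one user-side workaround (using only the standard Logging stdlib, not anything from DynamicExpressions) is to raise the minimum log level around the call. A minimal sketch:

using Logging

# Show only Error-level messages (and above) while differentiating,
# which silences the mixed-types @warn emitted by eval_tree_array.
with_logger(ConsoleLogger(stderr, Logging.Error)) do
    ForwardDiff.gradient(X) do X
        sum(abs2, first(eval_tree_array(expr, X, operators)))
    end
end
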
MilesCranmer commented 6 months ago

Awesome!

The warning indicates that eval_tree_array is promoting the value type of X together with that of the AbstractExpressionNode. Converting the node can be expensive, especially if it's a GraphNode, since it needs to make a copy of the whole tree.
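
For reference, a rough sketch of what that promotion amounts to on the caller side, continuing the example above and assuming convert(Node{T}, tree) accepts a Dual element type and makes a full copy of the tree:

# Sketch only: explicitly promote the node's value type to the Dual eltype,
# so eval_tree_array sees matching types (this copies the whole tree each call).
ForwardDiff.gradient(X) do Xdual
    expr_dual = convert(Node{eltype(Xdual)}, expr)
    return sum(abs2, first(eval_tree_array(expr_dual, Xdual, operators)))
end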

I wonder if that means it's also converting the Node{T} to a Node{Dual{T}}, and whether that's an issue at all. If it isn't, we could just disable the warning.
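
For what it's worth, a quick check of what that promotion would produce, assuming it goes through Base.promote_type (which ForwardDiff extends for Dual):

using ForwardDiff: Dual

# Promoting a Float32 tree value type against Dual data lands on a Dual type,
# i.e. Node{T} would become Node{Dual{...}} if the tree is converted:
promote_type(Float32, Dual{Nothing,Float64,10})
# ForwardDiff.Dual{Nothing, Float64, 10}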