EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Hanging and task switch errors when differentiating DynamicExpressions.jl #1018

Closed: MilesCranmer closed this issue 8 months ago

MilesCranmer commented 1 year ago

In trying to get DynamicExpressions.jl to be compatible with Enzyme.jl, I've taken the approach of using Base.Cartesian.@nif over the space of operators – so that the compiler knows exactly which operators are being used in the forward pass. This fixes all of the remaining type-stability issues.

However, it seems like I am running into other issues with this change, in particular a StackOverflowError (or simply having the evaluation hang – which is what happens in this MWE).

Here's a MWE. I tried to reduce it further but wasn't able to replicate the error, so unfortunately it is quite long. My intuition is that these nested Base.Cartesian.@nif calls (used for fusing operators) generate too many branches for Enzyme to handle. What do you think?
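(For context on the trick itself: Base.Cartesian.@nif turns a runtime operator index into a compile-time chain of branches over a fixed tuple of operators, so each branch body is specialized on a concrete operator. A minimal standalone sketch of that pattern – the apply_op helper below is hypothetical and only for illustration, not part of the MWE:)

# Hypothetical illustration of the @nif-over-operators pattern used in the MWE:
# dispatch on a runtime index op_idx over a tuple of operators known at
# compile time, so every branch body is fully type-stable.
@generated function apply_op(ops::Tuple, op_idx::Integer, x)
    N = length(ops.parameters)  # number of operators encoded in the tuple type
    quote
        Base.Cartesian.@nif($N, i -> i == op_idx, i -> ops[i](x))
    end
end

apply_op((cos, sin, exp), 2, 0.5)  # returns sin(0.5)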

using Enzyme

################################################################################
### OperatorEnum.jl
################################################################################
struct OperatorEnum{B,U}
    binops::B
    unaops::U
end
################################################################################

################################################################################
### Equation.jl
################################################################################
mutable struct Node{T}
    degree::UInt8  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    constant::Bool  # false if variable
    val::Union{T,Nothing}  # If is a constant, this stores the actual value
    # ------------------- (possibly undefined below)
    feature::UInt16  # If is a variable (e.g., x in cos(x)), this stores the feature index.
    op::UInt8  # If operator, this is the index of the operator in operators.binops, or operators.unaops
    l::Node{T}  # Left child node. Only defined for degree=1 or degree=2.
    r::Node{T}  # Right child node. Only defined for degree=2. 
    Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)

end
function Node(::Type{T}; val::T1=nothing, feature::T2=nothing)::Node{T} where {T,T1,T2}
    if T2 <: Nothing
        !(T1 <: T) && (val = convert(T, val))
        return Node(T, 0, true, val)
    else
        return Node(T, 0, false, nothing, feature)
    end
end
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
Node(op::Integer, l::Node{T}, r::Node{T}) where {T} = Node(2, false, nothing, 0, op, l, r)
################################################################################

################################################################################
### Utils.jl
################################################################################
@inline function fill_similar(value, array, args...)
    out_array = similar(array, args...)
    out_array .= value
    return out_array
end
is_bad_array(array) = !(isempty(array) || isfinite(sum(array)))
# This macro is a way of conditionally using LoopVectorization.@turbo. Here we just
# leave it off as it interferes with Enzyme
macro maybe_turbo(flag, ex)
    return :($(esc(ex)))
end
function is_constant(tree::Node)
    if tree.degree == 0
        return tree.constant
    elseif tree.degree == 1
        return is_constant(tree.l)
    else
        return is_constant(tree.l) && is_constant(tree.r)
    end
end

################################################################################

################################################################################
### EvaluateEquation.jl
################################################################################
struct ResultOk{A<:AbstractArray}
    x::A
    ok::Bool
end

macro return_on_check(val, X)
    :(!isfinite($(esc(val))) && return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false))
end
macro return_on_nonfinite_array(array)
    :(is_bad_array($(esc(array))) && return $(ResultOk)($(esc(array)), false))
end

function eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false)) where {T<:Number}
    v_turbo = if isa(turbo, Val)
        turbo
    else
        turbo ? Val(true) : Val(false)
    end
    if v_turbo === Val(true)
        @assert T in (Float32, Float64)
    end

    result = _eval_tree_array(tree, cX, operators, v_turbo)
    return (result.x, result.ok && !is_bad_array(result.x))
end
function eval_tree_array(
    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; kws...
) where {T1<:Number,T2<:Number}
    T = promote_type(T1, T2)
    @warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
    tree = convert(Node{T}, tree)
    cX = Base.Fix1(convert, T).(cX)
    return eval_tree_array(tree, cX, operators; kws...)
end

counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)

@generated function _eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo})::ResultOk where {T<:Number,turbo}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        # First, we see if there are only constants in the tree - meaning
        # we can just return the constant result.
        if tree.degree == 0
            return deg0_eval(tree, cX)
        elseif is_constant(tree)
            # Speed hack for constant trees.
            const_result = _eval_constant_tree(tree, operators)::ResultOk{Vector{T}}
            !const_result.ok && return ResultOk(similar(cX, axes(cX, 2)), false)
            return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true)
        elseif tree.degree == 1
            op_idx = tree.op
            # This @nif lets us generate an if statement over choice of operator,
            # which means the compiler will be able to completely avoid type inference on operators.
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> let op = operators.unaops[i]
                    if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
                        # op(op2(x, y)), where x and y are constants or variables.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nbin,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.binops[j]
                                deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, Val(turbo))
                            end,
                        )
                    elseif tree.l.degree == 1 && tree.l.l.degree == 0
                        # op(op2(x)), where x is a constant or variable.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nuna,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.unaops[j]
                                deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo))
                            end,
                        )
                    else
                        # op(x), for any x.
                        result = _eval_tree_array(tree.l, cX, operators, Val(turbo))
                        !result.ok && return result
                        @return_on_nonfinite_array result.x
                        deg1_eval(result.x, op, Val(turbo))
                    end
                end
            )
        else
            # TODO - add op(op2(x, y), z) and op(x, op2(y, z))
            # op(x, y), where x, y are constants or variables.
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> let op = operators.binops[i]
                    if tree.l.degree == 0 && tree.r.degree == 0
                        deg2_l0_r0_eval(tree, cX, op, Val(turbo))
                    elseif tree.r.degree == 0
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo))
                        !result_l.ok && return result_l
                        @return_on_nonfinite_array result_l.x
                        # op(x, y), where y is a constant or variable but x is not.
                        deg2_r0_eval(tree, result_l.x, cX, op, Val(turbo))
                    elseif tree.l.degree == 0
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo))
                        !result_r.ok && return result_r
                        @return_on_nonfinite_array result_r.x
                        # op(x, y), where x is a constant or variable but y is not.
                        deg2_l0_eval(tree, result_r.x, cX, op, Val(turbo))
                    else
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo))
                        !result_l.ok && return result_l
                        @return_on_nonfinite_array result_l.x
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo))
                        !result_r.ok && return result_r
                        @return_on_nonfinite_array result_r.x
                        # op(x, y), for any x or y
                        deg2_eval(result_l.x, result_r.x, op, Val(turbo))
                    end
                end
            )
        end
    end
end

function deg2_eval(
    cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{turbo}
)::ResultOk where {T<:Number,F,turbo}
    @maybe_turbo turbo for j in eachindex(cumulator_l)
        x = op(cumulator_l[j], cumulator_r[j])::T
        cumulator_l[j] = x
    end
    return ResultOk(cumulator_l, true)
end

function deg1_eval(
    cumulator::AbstractVector{T}, op::F, ::Val{turbo}
)::ResultOk where {T<:Number,F,turbo}
    @maybe_turbo turbo for j in eachindex(cumulator)
        x = op(cumulator[j])::T
        cumulator[j] = x
    end
    return ResultOk(cumulator, true)
end

function deg0_eval(tree::Node{T}, cX::AbstractMatrix{T})::ResultOk where {T<:Number}
    if tree.constant
        return ResultOk(fill_similar(tree.val::T, cX, axes(cX, 2)), true)
    else
        return ResultOk(cX[tree.feature, :], true)
    end
end

function deg1_l2_ll0_lr0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
) where {T<:Number,F,F2,turbo}
    if tree.l.l.constant && tree.l.r.constant
        val_ll = tree.l.l.val::T
        val_lr = tree.l.r.val::T
        @return_on_check val_ll cX
        @return_on_check val_lr cX
        x_l = op_l(val_ll, val_lr)::T
        @return_on_check x_l cX
        x = op(x_l)::T
        @return_on_check x cX
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.l.constant
        val_ll = tree.l.l.val::T
        @return_on_check val_ll cX
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @maybe_turbo turbo for j in axes(cX, 2)
            x_l = op_l(val_ll, cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.l.r.constant
        feature_ll = tree.l.l.feature
        val_lr = tree.l.r.val::T
        @return_on_check val_lr cX
        cumulator = similar(cX, axes(cX, 2))
        @maybe_turbo turbo for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], val_lr)::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature_ll = tree.l.l.feature
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @maybe_turbo turbo for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(op2(x)) for x variable or constant
function deg1_l1_ll0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
) where {T<:Number,F,F2,turbo}
    if tree.l.l.constant
        val_ll = tree.l.l.val::T
        @return_on_check val_ll cX
        x_l = op_l(val_ll)::T
        @return_on_check x_l cX
        x = op(x_l)::T
        @return_on_check x cX
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    else
        feature_ll = tree.l.l.feature
        cumulator = similar(cX, axes(cX, 2))
        @maybe_turbo turbo for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
) where {T<:Number,F,turbo}
    if tree.l.constant && tree.r.constant
        val_l = tree.l.val::T
        @return_on_check val_l cX
        val_r = tree.r.val::T
        @return_on_check val_r cX
        x = op(val_l, val_r)::T
        @return_on_check x cX
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.constant
        cumulator = similar(cX, axes(cX, 2))
        val_l = tree.l.val::T
        @return_on_check val_l cX
        feature_r = tree.r.feature
        @maybe_turbo turbo for j in axes(cX, 2)
            x = op(val_l, cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.r.constant
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        val_r = tree.r.val::T
        @return_on_check val_r cX
        @maybe_turbo turbo for j in axes(cX, 2)
            x = op(cX[feature_l, j], val_r)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        feature_r = tree.r.feature
        @maybe_turbo turbo for j in axes(cX, 2)
            x = op(cX[feature_l, j], cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x variable/constant, y arbitrary
function deg2_l0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
) where {T<:Number,F,turbo}
    if tree.l.constant
        val = tree.l.val::T
        @return_on_check val cX
        @maybe_turbo turbo for j in eachindex(cumulator)
            x = op(val, cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.l.feature
        @maybe_turbo turbo for j in eachindex(cumulator)
            x = op(cX[feature, j], cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x arbitrary, y variable/constant
function deg2_r0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
) where {T<:Number,F,turbo}
    if tree.r.constant
        val = tree.r.val::T
        @return_on_check val cX
        @maybe_turbo turbo for j in eachindex(cumulator)
            x = op(cumulator[j], val)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.r.feature
        @maybe_turbo turbo for j in eachindex(cumulator)
            x = op(cumulator[j], cX[feature, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
@generated function _eval_constant_tree(tree::Node{T}, operators::OperatorEnum) where {T<:Number}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        if tree.degree == 0
            return deg0_eval_constant(tree)::ResultOk{Vector{T}}
        elseif tree.degree == 1
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> deg1_eval_constant(
                    tree, operators.unaops[i], operators
                )::ResultOk{Vector{T}}
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> deg2_eval_constant(
                    tree, operators.binops[i], operators
                )::ResultOk{Vector{T}}
            )
        end
    end
end
@inline function deg0_eval_constant(tree::Node{T}) where {T<:Number}
    output = tree.val::T
    return ResultOk([output], true)::ResultOk{Vector{T}}
end
function deg1_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    result = _eval_constant_tree(tree.l, operators)
    !result.ok && return result
    output = op(result.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
function deg2_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    cumulator = _eval_constant_tree(tree.l, operators)
    !cumulator.ok && return cumulator
    result_r = _eval_constant_tree(tree.r, operators)
    !result_r.ok && return result_r
    output = op(cumulator.x[], result_r.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
################################################################################

# Using the above to get gradients:

operators = OperatorEnum((+, -, *, /), (cos, sin))

x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)

tree = Node(1, x1, Node(1, x2))  # == x1 + cos(x2)

X = randn(3, 100);
dX = zero(X)

eval_tree_array(tree, X, operators)

f(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators)[1]); nothing)

output = [0.0]
doutput = [1.0]

autodiff(
    Reverse,
    f,
    Const(tree),
    Duplicated(X, dX),
    Const(operators),
    Duplicated(output, doutput)
)

dX

Please ask for any clarification on this code. It's copied from DynamicExpressions.jl to be a self-contained example.

wsmoses commented 1 year ago

Can you also include the full backtrace/error?


MilesCranmer commented 1 year ago

Unfortunately there is none; it just exits immediately. Is there a way I can force it to log the stack overflow?

MilesCranmer commented 1 year ago

When it just hangs, and I force it to quit, I get this:

> julia mwe.jl
^C
[64807] signal (2): Interrupt: 2
in expression starting at /Users/mcranmer/Documents/DynamicExpressions.jl/mwe.jl:452
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
_ZN4llvm5Value12setValueNameEPNS_14StringMapEntryIPS0_EE at /Users/mcranmer/.julia/juliaup/julia-1.10.0-beta2+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib (unknown line)
unknown function (ip: 0x0)
Allocations: 18404280 (Pool: 18382207; Big: 22073); GC: 16
Error: Maybe we should never reach this?
wsmoses commented 1 year ago

If you have a version that hits the stack overflow, that would likely be much easier for us to use to figure out the source of the issue (especially if that same error produced this behavior).


MilesCranmer commented 1 year ago

Weirdly, I can't reproduce the specific StackOverflowError I saw earlier. I was able to get it by fixing some parts of the evaluation code, but I forget how I did it... I will try to find it later.

However, I can get this task-switch error, though it's not as self-contained as the example above. It does include a stack trace, so maybe it's of some use?

First, check out the Enzyme-enabled branches of SymbolicRegression.jl and DynamicExpressions.jl:

cd $(mktemp -d)

git clone git@github.com:SymbolicML/DynamicExpressions.jl.git
pushd DynamicExpressions.jl
git checkout avoid-iterated-outputs
popd

git clone git@github.com:MilesCranmer/SymbolicRegression.jl.git
pushd SymbolicRegression.jl
git checkout enzyme2
julia --project=. -e 'using Pkg; Pkg.add(path="../DynamicExpressions.jl"); Pkg.add("Enzyme")'

Then, in that same SymbolicRegression.jl project, the following code will trigger a task switch error:

using SymbolicRegression, Enzyme

X = randn(Float32, 5, 100)
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2

options = SymbolicRegression.Options(;
    binary_operators=[+, *, /, -],
    unary_operators=[cos, exp],
    populations=20,
    enable_enzyme=true,
)

hall_of_fame = equation_search(
    X, y; niterations=40, options=options
)

Error:

ERROR: TaskFailedException
Stacktrace:
 [1] wait
   @ SymbolicRegression ./task.jl:352 [inlined]
 [2] fetch
   @ SymbolicRegression ./task.jl:372 [inlined]
 [3] _equation_search(::Val{…}, ::Val{…}, datasets::Vector{…}, niterations::Int64, options::Options{…}, numprocs::Nothing, procs::Nothing, addprocs_function::Nothing, runtests::Bool, saved_state::Nothing, verbosity::Int64, progress::Bool, ::Val{…})
   @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:822
 [4] equation_search(datasets::Vector{…}; niterations::Int64, options::Options{…}, parallelism::Symbol, numprocs::Nothing, procs::Nothing, addprocs_function::Nothing, runtests::Bool, saved_state::Nothing, return_state::Nothing, verbosity::Nothing, progress::Nothing, v_dim_out::Val{…})
   @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:505
 [5] equation_search(X::Matrix{…}, y::Matrix{…}; niterations::Int64, weights::Nothing, options::Options{…}, variable_names::Nothing, display_variable_names::Nothing, y_variable_names::Nothing, parallelism::Symbol, numprocs::Nothing, procs::Nothing, addprocs_function::Nothing, runtests::Bool, saved_state::Nothing, return_state::Nothing, loss_type::Type{…}, verbosity::Nothing, progress::Nothing, X_units::Nothing, y_units::Nothing, v_dim_out::Val{…}, multithreaded::Nothing, varMap::Nothing)
   @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:383
 [6] equation_search
   @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:328 [inlined]
 [7] #equation_search#24
   @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:412 [inlined]
 [8] top-level scope
   @ REPL[5]:1

    nested task error: TaskFailedException
    Stacktrace:
     [1] wait
       @ SymbolicRegression ./task.jl:352 [inlined]
     [2] fetch
       @ SymbolicRegression ./task.jl:372 [inlined]
     [3] (::SymbolicRegression.var"#46#73"{Vector{Vector{Channel{Any}}}, Vector{Vector{Task}}, Int64, Int64})()
       @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:795

        nested task error: task switch not allowed from inside staged nor pure functions
        Stacktrace:
          [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
            @ Base ./task.jl:921
          [2] wait()
            @ Base ./task.jl:995
          [3] wait(c::Base.GenericCondition{Base.Threads.SpinLock}; first::Bool)
            @ Base ./condition.jl:130
          [4] wait
            @ Base ./condition.jl:125 [inlined]
          [5] (::Base.var"#slowlock#647")(rl::ReentrantLock)
            @ Base ./lock.jl:156
          [6] lock(f::GPUCompiler.var"#8#10"{GPUCompiler.CompilerJob{…}}, l::ReentrantLock)
            @ Base ./lock.jl:147 [inlined]
          [7] cached_compilation
            @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9701 [inlined]
          [8] (::Enzyme.Compiler.var"#475#476"{…})(ctx::LLVM.Context)
            @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9768
          [9] JuliaContext(f::Enzyme.Compiler.var"#475#476"{…})
            @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
         [10] #s292#474
            @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9723 [inlined]
         [11]
            @ Enzyme.Compiler ./none:0
         [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
            @ Core ./boot.jl:600
         [13] autodiff
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:207 [inlined]
         [14] autodiff
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:236 [inlined]
         [15] autodiff
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:222 [inlined]
         [16] opt_func_g!
            @ SymbolicRegression.ConstantOptimizationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/ext/SymbolicRegressionEnzymeExt.jl:26 [inlined]
         [17] (::SymbolicRegression.ConstantOptimizationModule.var"#g!#9"{…})(storage::Vector{…}, x::Vector{…})
            @ SymbolicRegression.ConstantOptimizationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/ConstantOptimization.jl:101
         [18] (::NLSolversBase.var"#fg!#8"{…})(gx::Vector{…}, x::Vector{…})
            @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/abstract.jl:13
         [19] value_gradient!!(obj::NLSolversBase.OnceDifferentiable{Float32, Vector{Float32}, Vector{Float32}}, x::Vector{Float32})
            @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
         [20] initial_state(method::Optim.BFGS{…}, options::Optim.Options{…}, d::NLSolversBase.OnceDifferentiable{…}, initial_x::Vector{…})
            @ Optim ~/.julia/packages/Optim/dBGGV/src/multivariate/solvers/first_order/bfgs.jl:94
         [21] optimize(d::D, initial_x::Tx, method::M, options::Optim.Options{…}, state::Any) where {…}
            @ Optim ~/.julia/packages/Optim/dBGGV/src/multivariate/optimize/optimize.jl:36 [inlined]
         [22] #optimize#91
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/Optim/dBGGV/src/multivariate/optimize/interface.jl:156 [inlined]
         [23] optimize
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/Optim/dBGGV/src/multivariate/optimize/interface.jl:151 [inlined]
         [24] _optimize_constants(dataset::Dataset{…}, member::PopMember{…}, options::Options{…}, algorithm::Optim.BFGS{…}, optimizer_options::Optim.Options{…}, idx::Nothing, ::Val{…})
            @ SymbolicRegression.ConstantOptimizationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/ConstantOptimization.jl:108
         [25] call_opt
            @ SymbolicRegression.ConstantOptimizationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/ConstantOptimization.jl:52 [inlined]
         [26] dispatch_optimize_constants(dataset::Dataset{…}, member::PopMember{…}, options::Options{…}, idx::Nothing)
            @ SymbolicRegression.ConstantOptimizationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/ConstantOptimization.jl:71
         [27] optimize_constants
            @ SymbolicRegression.SingleIterationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/ConstantOptimization.jl:43 [inlined]
         [28] optimize_and_simplify_population(dataset::Dataset{…}, pop::Population{…}, options::Options{…}, curmaxsize::Int64, record::Dict{…})
            @ SymbolicRegression.SingleIterationModule /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SingleIteration.jl:114
         [29] macro expansion
            @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SymbolicRegression.jl:754 [inlined]
         [30] (::SymbolicRegression.var"#44#71"{…})()
            @ SymbolicRegression /private/var/folders/1h/xyppkvx52cl6w3_h8bw_gdqh0000gr/T/tmp.NsK6zBkT/SymbolicRegression.jl/src/SearchUtils.jl:41
Some type information was truncated. Use `show(err)` to see complete types.
MilesCranmer commented 1 year ago

Anything else I can do to help with this?

One idea I had was to just comment out pieces of the evaluation kernel until it stops hanging. It seems that if the kernel only contains the branch that actually gets executed (deg1_eval), there is no hang. So it's something to do with compiling the other branches.
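For reference, here is roughly what that reduction looks like as a standalone function, assuming the definitions from the MWE at the top of this issue (a sketch I have not re-run in exactly this form; it drops the fused deg1_l2_ll0_lr0_eval / deg1_l1_ll0_eval branches and the constant-tree shortcut, so every operator goes through the generic fallback path):

@generated function _eval_tree_array_reduced(
    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
)::ResultOk where {T<:Number,turbo}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        if tree.degree == 0
            return deg0_eval(tree, cX)
        elseif tree.degree == 1
            op_idx = tree.op
            # Only the generic fallback: no nested @nif over the child's operator.
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> let op = operators.unaops[i]
                    result = _eval_tree_array_reduced(tree.l, cX, operators, Val(turbo))
                    !result.ok && return result
                    deg1_eval(result.x, op, Val(turbo))
                end
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> let op = operators.binops[i]
                    result_l = _eval_tree_array_reduced(tree.l, cX, operators, Val(turbo))
                    !result_l.ok && return result_l
                    result_r = _eval_tree_array_reduced(tree.r, cX, operators, Val(turbo))
                    !result_r.ok && return result_r
                    deg2_eval(result_l.x, result_r.x, op, Val(turbo))
                end
            )
        end
    end
end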

wsmoses commented 1 year ago

This requires thoughts from @vchuravy regarding the lock/unlock inside a generated function, from here: https://github.com/EnzymeAD/Enzyme.jl/blob/628c9a4593efb455adf88aba981f8587e56da5c1/src/compiler.jl#L9743

wsmoses commented 1 year ago

I suppose we could write our own spinlock with some CAS's

MilesCranmer commented 1 year ago

Not sure if this is related but I wonder if that cached_compilation function should have a line like:

haskey(cache, key) && return cache[key]

before triggering the lock? If the key already exists in the cache, and there appears to be no code that mutates an existing cache entry, then it seems like it might hit a lot of redundant locks?

wsmoses commented 1 year ago

So if it's inside the locked region, that means another thread has started compiling it but not yet completed.


MilesCranmer commented 1 year ago

i.e., something like this?

+    haskey(cache, key) && return cache[key]
     lock(cache_lock)
     try
+        haskey(cache, key) && return cache[key]  # if set after lock
+        asm = _thunk(job)
+        obj = _link(job, asm)
+        cache[key] = obj
-        obj = get(cache, key, nothing)
-        if obj === nothing
-            asm = _thunk(job)
-            obj = _link(job, asm)
-            cache[key] = obj
-        end
         obj
     finally
         unlock(cache_lock)
     end

(Also type stable this way)

MilesCranmer commented 1 year ago

Oops, should be this:

+    haskey(cache, key) && return cache[key]
     lock(cache_lock)
     try
+        if haskey(cache, key)  # if set after lock
+            cache[key]
+        else
+            asm = _thunk(job)
+            obj = _link(job, asm)
+            cache[key] = obj
+            obj
+       end
-        obj = get(cache, key, nothing)
-        if obj === nothing
-            asm = _thunk(job)
-            obj = _link(job, asm)
-            cache[key] = obj
-        end
-        obj
     finally
         unlock(cache_lock)
     end
wsmoses commented 1 year ago

Besides other issues regarding the need to lock, I don't think this will resolve the problem. Specifically, the case you hit only happens when two threads request compilation at the same time, in which case haskey would not return true.

MilesCranmer commented 1 year ago

I see, thanks. (Though I think that change might speed up compilation anyway? Otherwise it is locking the cache for reads, when you only need the lock for writes.)
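For what it's worth, the pattern being discussed here is just double-checked locking. A generic sketch with a plain Dict and ReentrantLock (an illustration only, not Enzyme's actual cached_compilation; cached_compile, _cache, and _cache_lock are hypothetical names, and note the caveat in the comments about unlocked Dict reads):

const _cache = Dict{UInt,Any}()
const _cache_lock = ReentrantLock()

# Hypothetical illustration of the double-checked pattern: a fast path that
# reads without taking the lock, and a re-check after acquiring it, so the
# lock is only held while actually compiling and writing the entry.
# Caveat (part of the "other issues regarding the need to lock"): reading a
# plain Dict while another task writes to it is not safe in Julia, so a real
# implementation needs a concurrency-safe cache or must lock reads as well.
function cached_compile(key::UInt, compile::F) where {F}
    haskey(_cache, key) && return _cache[key]      # fast path, no lock
    lock(_cache_lock)
    try
        haskey(_cache, key) && return _cache[key]  # may have been filled while we waited
        obj = compile()
        _cache[key] = obj
        return obj
    finally
        unlock(_cache_lock)
    end
end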

Back to the original MWE, I did find that commenting out these parts of the code fixed the issue:

                    if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
                        # op(op2(x, y)), where x and y are constants or variables.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nbin,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.binops[j]
                                deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, Val(turbo))
                            end,
                        )
                    elseif tree.l.degree == 1 && tree.l.l.degree == 0
                        # op(op2(x)), where x is a constant or variable.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nuna,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.unaops[j]
                                deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo))
                            end,
                        )

So maybe it's just an overly complex function (N if statements, each containing N more if statements, each describing a different kernel function)? That was making compilation take exponentially longer, to the point where it got stuck...
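To put a rough number on it for the operator set in the MWE (my own back-of-envelope count, assuming each @nif branch body becomes its own specialized code path):

# Rough count of code paths generated by the nested @nif blocks above.
nbin, nuna = 4, 2                        # (+, -, *, /) and (cos, sin)
deg1_paths = nuna * (nbin + nuna + 1)    # each unary op gets fused-binary, fused-unary, and fallback bodies
deg2_paths = nbin * 4                    # each binary op gets four left/right degree cases
total = deg1_paths + deg2_paths          # 14 + 16 = 30 paths for this small operator set
# All of these get compiled (and differentiated) even though only one path is
# taken at runtime, and the count grows multiplicatively with more operators.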

MilesCranmer commented 1 year ago

Okay, the following two commits seem to have fixed this issue in SymbolicRegression.jl. Basically, I just turn off operator fusing when using Enzyme:

https://github.com/SymbolicML/DynamicExpressions.jl/pull/52/commits/13de9e9fd4d89dc1b495faad1225544b73ff74b9

and

https://github.com/MilesCranmer/SymbolicRegression.jl/pull/254/commits/5b9e7d3c2f197116c9b4e035143af0e5adc4af3f
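At the call site the change looks roughly like this (a sketch; the fuse_level keyword comes from the branch in the linked DynamicExpressions.jl PR, as also used in the later MWE in this thread, and is not in a released version):

# Default: fused kernels via nested @nif (fast to evaluate, hard on Enzyme's compile times).
y_fused, ok1 = eval_tree_array(tree, X, operators)
# fuse_level=Val(1): skip the fused kernels so Enzyme only sees the simple
# per-operator loops (slower to evaluate, but differentiable without hanging).
y_plain, ok2 = eval_tree_array(tree, X, operators; fuse_level=Val(1))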

MilesCranmer commented 1 year ago

Not sure if this is helpful, but here are some results of profiling a compilation run on a reduced evaluation kernel:

compilation_profiling.jlprof.zip

(open the .jlprof in ProfileView). I generated this by first calling autodiff, then making a tiny change to the function and running autodiff again to get the relevant time.

The flame graph looks like this (with ProfileView.view(C=true)):

[Screenshot: flame graph of the compilation profile]

If we zoom in, pretty much the entirety is libLLVM calls, with that one peak being the Enzyme.jl stuff:

[Screenshot: zoomed-in flame graph, dominated by libLLVM calls]

so unfortunately I'm not really sure how to help speed things up, or where the inefficiencies are hiding that make this branched code so difficult for Enzyme to differentiate.
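For completeness, the profiling workflow described above looks roughly like this (a sketch; it assumes the reduced evaluation kernel plus the f, tree, X, dX, operators, output, and doutput variables from the MWE, and that this variant actually finishes compiling):

using Profile

# First call: compiles the evaluation helpers plus the Enzyme pipeline.
autodiff(Reverse, f, Const(tree), Duplicated(X, dX), Const(operators), Duplicated(output, doutput))

# Make a tiny change to the target function so autodiff has to recompile,
# then profile only that recompilation.
f2(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators)[1]) + 0.0; nothing)
Profile.clear()
@profile autodiff(Reverse, f2, Const(tree), Duplicated(X, dX), Const(operators), Duplicated(output, doutput))

# Inspect with ProfileView.view(C=true), or dump a flat listing:
# Profile.print(format=:flat, C=true, sortedby=:overhead)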

MilesCranmer commented 1 year ago

Here's a flat CSV version of the relevant calls:

flat_profile.csv

The top calls:

Count Overhead File Line Function
10959 10959 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? ZNK4llvm11Instruction24isIdenticalToWhenDefinedEPKS0
4161 4161 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZNK4llvm10BasicBlock19getFirstInsertionPtEv
3522 3522 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm16InstCombinerImpl12visitPHINodeERNS_7PHINodeE
2291 2291 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm22MustBeExecutedIterator7advanceEv
1957 1681 /usr/lib/system/libsystem_malloc.dylib ? tiny_free_no_lock
1545 1545 /usr/lib/system/libsystem_malloc.dylib ? free_tiny
1748 1402 /usr/lib/system/libsystem_malloc.dylib ? tiny_malloc_from_free_list
1318 1318 /usr/lib/system/libsystem_malloc.dylib ? tiny_malloc_should_clear
1295 1295 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZNK4llvm9LiveRange12overlapsFromERKS0_PKNS0_7SegmentE
1073 1073 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm17LiveIntervalUnion5Query23collectInterferingVRegsEj
1072 1072 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm11IntervalMapINS_9SlotIndexEPNS_12LiveIntervalELj8ENS_15IntervalMapInfoIS1_EEE14constiterator12pathFillFindES1
1064 1064 /usr/lib/system/libsystem_malloc.dylib ? tiny_free_list_add_ptr
1015 1015 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm8DenseMapINS_14PointerIntPairIPKNS_11InstructionELj1ENS_20ExplorationDirectionENS_21PointerLikeTypeTraitsIS4_EENS_18PointerIntPairInfoIS4_Lj1ES7_EEEENS_6detail13DenseSetEmptyENS_12DenseMapInfoISA_vEENSB_12DenseSetPairISA_EEE4growEj
767 767 /usr/lib/system/libsystem_malloc.dylib ? _szone_free
760 760 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm16FoldingSetNodeID10AddPointerEPKv
731 731 /Users/mcranmer/.julia/artifacts/def2892b4e5b4f69c074bd354565880dde51b72d/lib/libEnzyme-14.dylib ? _ZNSt316treeINS_12value_typeIKNS_6vectorIiNS_9allocatorIiEEEE12ConcreteTypeEENS_19map_value_compareIS6_S8_NS_4lessIS6_EELb1EEENS3_IS8_EEE12find_equalIS5_EERPNS_16__tree_node_baseIPvEENS_21tree_const_iteratorIS8_PNS_11tree_nodeIS8_SH_EElEERPNS_15tree_end_nodeISJ_EESKRKT
725 725 /usr/lib/system/libsystem_malloc.dylib ? free
716 716 /Users/mcranmer/.julia/artifacts/def2892b4e5b4f69c074bd354565880dde51b72d/lib/libEnzyme-14.dylib ? _ZNSt3127tree_balance_after_insertIPNS_16__tree_node_baseIPvEEEEvTS5
663 663 /usr/lib/system/libsystem_platform.dylib ? _platform_memmove
632 632 /usr/lib/system/libsystem_platform.dylib ? _platform_memcmp
571 571 /usr/lib/system/libsystem_malloc.dylib ? _malloc_zone_malloc
563 563 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm11IntervalMapINS_9SlotIndexEPNS_12LiveIntervalELj8ENS_15IntervalMapInfoIS1_EEE14constiterator4findES1
539 539 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm10FoldingSetINS_6SDNodeEE10NodeEqualsEPKNS_14FoldingSetBaseEPNS3_4NodeERKNS16FoldingSetNodeIDEjRS8
531 531 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm16FoldingSetNodeID10AddIntegerEj
518 518 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZNK4llvm6SDNode15hasNUsesOfValueEjj
514 514 /usr/lib/system/libsystem_malloc.dylib ? tiny_free_list_remove_ptr
1506 511 /Users/mcranmer/.julia/artifacts/def2892b4e5b4f69c074bd354565880dde51b72d/lib/libEnzyme-14.dylib ? _ZN8TypeTree11checkedOrInERKNSt3__16vectorIiNS0_9allocatorIiEEEE12ConcreteTypebRb
762 500 /Users/mcranmer/.julia/artifacts/def2892b4e5b4f69c074bd354565880dde51b72d/lib/libEnzyme-14.dylib ? _ZN8TypeTree6insertENSt3__16vectorIiNS0_9allocatorIiEEEE12ConcreteTypeb
492 492 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZN4llvm18RegPressureTracker20bumpDownwardPressureEPKNS_12MachineInstrE
462 462 /Users/mcranmer/.julia/juliaup/julia-1.9.3+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib ? _ZNK4llvm12DenseMapBaseINS_8DenseMapINSt3__14pairIPNS_5ValueEjEENS_19ValueLatticeElementENS_12DenseMapInfoIS6_vEENS_6detail12DenseMapPairIS6_S7_EEEES6_S7_S9_SC_E15LookupBucketForIS6_EEbRKTRPKSC

(note to self: generated with

julia> withenv("COLUMNS" => 300000) do 
           open("flat_profile.txt", "w") do io
               Profile.print(io, format=:flat, C=true, sortedby=:overhead, threads=1)
           end
       end

and then parsed with Excel)

MilesCranmer commented 1 year ago

Quick update:

Even with the changes mentioned above (https://github.com/SymbolicML/DynamicExpressions.jl/pull/52 and https://github.com/MilesCranmer/SymbolicRegression.jl/pull/254), which get 1st-order differentiation to stop hanging, 2nd-order differentiation now seems to have very long compilation times as well.

Here's an example (the MWE type system in my first comment should reproduce this as well):

using DynamicExpressions
using Enzyme: autodiff, autodiff_deferred, Reverse, Const, Duplicated

const idx_r = 1
const idx_t = 2

function eval_and_sum!(buffer, tree, X, operators)
    # Turn off LoopVectorization.jl with `turbo=Val(false)`
    # Turn off excessively branched code with `fuse_level=Val(1)`
    buffer[] = sum(eval_tree_array(tree, X, operators; turbo=Val(false), fuse_level=Val(1))[1])
    return nothing
end

function compute_dX!(buffer, tree, X, operators, autodiff::F) where {F<:Function}
    tmp_dX = zero(X)
    output = [zero(eltype(X))]
    doutput = [one(eltype(X))]
    autodiff(
        Reverse,
        eval_and_sum!,
        Duplicated(output, doutput),
        Const(tree),
        Duplicated(X, tmp_dX),
        Const(operators),
    )
    buffer .= tmp_dX
    return nothing
end

function compute_∂χ(tree, X, operators, ::Val{idx}, autodiff::F) where {idx,F}
    dX = similar(X)
    compute_dX!(dX, tree, X, operators, autodiff)
    return view(dX, idx, :)
end

function compute_∂χ2(tree, X, operators, ::Val{idx}) where {idx}
    function f(buffer, tree, X, operators)
        buffer[] = sum(compute_∂χ(tree, X, operators, Val(idx), autodiff_deferred))
        return nothing
    end
    output = [zero(eltype(X))]
    doutput = [one(eltype(X))]
    d_∂X∂idx = zero(X)
    autodiff(
        Reverse,
        f,
        Duplicated(output, doutput),
        Const(tree),
        Duplicated(X, d_∂X∂idx),
        Const(operators),
    )
    return view(d_∂X∂idx, idx, :)
end

operators = OperatorEnum(
    binary_operators=(+, -, *, /),
    unary_operators=(cos,), 
)

N = 1000
r = rand(N) .* 100.0
t = rand(N) .* 200.0 .- 100.0

X = zeros(Float64, 2, N)
X[idx_r, :] .= r
X[idx_t, :] .= t

test_tree = let r = Node(Float64, feature=1), t = Node(Float64, feature=2)
    2.3 * r + 0.9 * t
end

The following calculation, which computes a 1st-order derivative, works fine:

∂χ_∂r = compute_∂χ(test_tree, X, operators, Val(idx_r), autodiff)

and successfully returns a vector filled with 2.3, which is correct, as we have differentiated $2.3r + 0.9t$ with respect to r (i.e., idx_r = the index of r).

Compilation is not too slow (though it still requires the changes I mentioned above to avoid hanging), and evaluation on subsequent calls is pretty fast.

However, computing the second-order derivative seems to take an unreasonably long time (I waited at least an hour before quitting):

∂χ_∂rr = compute_∂χ2(test_tree, X, operators, Val(idx_r))

(By the way... I'm not sure if any of this is useful so let me know how I can make myself helpful.)

wsmoses commented 1 year ago

The thing that would be most useful is finding the most minimal code that triggers the slowdown, where minimal means total LLVM instructions (more complicated Julia code, like broadcasting, recursion, and packages, will emit more LLVM), and ideally only a first derivative. That would allow us to try to figure out where it is being slow, why it is being slow, and hopefully fix it. I'm booked solid for the next week and a half, so I won't have time to look at this until after then, but if you're able to get a truly minimal example in the interim, it'll give me a chance of figuring out what's wrong when I have cycles.

MilesCranmer commented 1 year ago

Thanks.

A more minimal MWE might not be possible here, as the core issue seems to be the number of branches the code can take. When I reduce the number of branches to make the example smaller, the issue pretty much goes away... but with a few more branches, the compilation time gets exponentially longer. And second-order derivatives seem to introduce yet another factor here. So I'm not really sure how to reduce things while still reproducing the blowup in compilation times.
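One rough way to see the scaling without a fully minimal example (a sketch reusing the definitions from my first MWE; the smaller operator tuples are chosen just for illustration, and the full four-binary/two-unary set is the one that hangs outright, so I stop before it):

# Time the first (compile-heavy) autodiff call as the operator set grows.
# Each OperatorEnum has a distinct type, so Enzyme recompiles for every entry.
for ops in (
    OperatorEnum((+,), (cos,)),
    OperatorEnum((+, -), (cos, sin)),
    OperatorEnum((+, -, *), (cos, sin)),
)
    local x1 = Node(Float64; feature=1)
    local x2 = Node(Float64; feature=2)
    local tr = Node(1, x1, Node(1, x2))
    local Xs = randn(3, 100)
    local dXs = zero(Xs)
    local out = [0.0]
    local dout = [1.0]
    t = @elapsed autodiff(
        Reverse, f, Const(tr), Duplicated(Xs, dXs), Const(ops), Duplicated(out, dout)
    )
    @info "first autodiff call" length(ops.binops) length(ops.unaops) t
end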

MilesCranmer commented 11 months ago

Moving to #1156 with a new MWE

wsmoses commented 8 months ago

Closing, since we now have a smaller issue (#1156).