EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
439 stars 62 forks source link

Segmentation Fault #514

Closed freddycct closed 1 year ago

freddycct commented 1 year ago

An example of a Recursive Neural Network on Enzyme (state-of-the-art LLVM Auto Diff).

using Enzyme

"""
    Node

Common supertype of the tree elements (`Leaf` and `Branch`).
"""
abstract type Node end

"""
    Leaf(params, x)

Terminal tree node: a parameter set `params` plus an input vector `x`.
"""
struct Leaf <: Node
    params::NamedTuple   # the call method reads `params.w2` and `params.b2`
    x::Vector{Float32}
end

# Evaluate the leaf. Returns `(y, x)` where `y` is the first entry of
# `w2 * x .+ b2` and `x` is the stored input vector itself (not a copy).
function (n::Leaf)()::Tuple{Float32, Vector{Float32}}
    p = n.params
    scores = p.w2 * n.x .+ p.b2
    return first(scores), n.x
end

"""
    Branch(params, left, right)

Interior tree node: a parameter set plus two child nodes.
"""
struct Branch <: Node
    params::NamedTuple   # the call method reads `w1`, `b1`, `w2` and `b2`
    left::Node
    right::Node
end

# Evaluate the subtree rooted at `n`. Returns `(y, h)`: `h` is the hidden vector
# `w1 * vcat(h_left, h_right) .+ b1`, and `y` adds this node's own score (first
# entry of `w2 * h .+ b2`) to the scores of both children.
function (n::Branch)()::Tuple{Float32, Vector{Float32}}
    left_score, left_hidden = n.left()
    right_score, right_hidden = n.right()

    p = n.params
    hidden = p.w1 * vcat(left_hidden, right_hidden) .+ p.b1
    own_score = first(p.w2 * hidden .+ p.b2)

    return left_score + right_score + own_score, hidden
end

"""
    genTree(N, K, ps, p) -> Node

Build a random tree. When `rand() < p` the result is a `Leaf` carrying a random
length-`K` `Float32` vector; otherwise it is a `Branch` whose two children are
generated recursively. Every node takes its parameters from `ps[rand(1:N)]`.
"""
function genTree(N::Int, K::Int, ps::Vector, p::Float64)::Node
    make_leaf = rand() < p
    params = ps[rand(1:N)]   # drawn right after the coin flip in either case
    make_leaf && return Leaf(params, rand(Float32, K))

    left = genTree(N, K, ps, p)
    right = genTree(N, K, ps, p)
    return Branch(params, left, right)
end

"""
    loss(t1, t2)

Difference between the scalar outputs (first tuple element) of calling `t1` and `t2`.
"""
loss(t1::Node, t2::Node) = first(t1()) - first(t2())

"""
    mirror(n, mm)

Rebuild the tree rooted at `n` with every parameter set replaced by `mm[n.params]`.
Leaf input vectors are replaced by `zero(n.x)`.
"""
mirror(n::Leaf, mm::Dict) = Leaf(mm[n.params], zero(n.x))

function mirror(n::Branch, mm::Dict)
    shadow_params = mm[n.params]
    shadow_left = mirror(n.left, mm)
    shadow_right = mirror(n.right, mm)
    return Branch(shadow_params, shadow_left, shadow_right)
end

# Build two random trees over a shared pool of parameter sets, mirror them into
# zero-initialised shadow trees, and differentiate `loss` in reverse mode.
function main()
    N = 10  # number of parameter sets in the pool
    K = 5   # length of the leaf / hidden vectors
    ps = map(1:N) do x
        (
            w1 = randn(Float32, (K,2*K)),
            b1 = randn(Float32, K),
            w2 = randn(Float32, (1,K)),
            b2 = randn(Float32)
        )
    end

    # Shadow storage, one entry per parameter set. Every field has the same type
    # as its primal counterpart, so `b2` is a Float32 zero rather than the Int `0`.
    grads = map(1:N) do x
        (
            w1 = zeros(Float32, (K,2*K)),
            b1 = zeros(Float32, K),
            w2 = zeros(Float32, (1,K)),
            b2 = 0.0f0
        )
    end

    # primal parameter set => its shadow entry, consumed by `mirror`
    mm = Dict(p=>g for (p, g) in zip(ps, grads))

    # create a tree
    t1 = genTree(N, K, ps, 0.5)
    t2 = genTree(N, K, ps, 0.5)

    t1Grads = mirror(t1, mm)
    t2Grads = mirror(t2, mm)

    @show ll = loss(t1, t2)
    @show loss(t1Grads, t2Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads), Duplicated(t2, t2Grads))

    @show ll = loss(t1, t2)
end

main()

error

Illegal orIn: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer} right: {[-1]:Pointer, [-1,8]:Pointer, [-1,16]:Pointer} PointerIntSame=0
Assertion failed: (0 && "Performed illegal ConcreteType::orIn"), function orIn, file /workspace/srcdir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeTree.h, line 997.

signal (6): Abort trap: 6
in expression starting at /Users/freddy/Documents/projects/earley/enzyme_tree_mwe.jl:90
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 40874044 (Pool: 40834696; Big: 39348); GC: 45
zsh: abort      julia enzyme_tree_mwe.jl
wsmoses commented 1 year ago

On the latest main you should now get a nicer error, which can be resolved by adding the flags below. However, when I run it with those flags I still get a segfault.

@vchuravy

using Enzyme
#Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
Enzyme.API.strictAliasing!(false)
Enzyme.API.looseTypeAnalysis!(true)

# Supertype shared by `Leaf` and `Branch`.
abstract type Node end

# Terminal node: a parameter set and an input vector.
struct Leaf <: Node
    params::NamedTuple   # `w2` and `b2` are read when the leaf is called
    x::Vector{Float32}
end

# Calling a leaf returns `(y, x)`, with `y` the first entry of `w2 * x .+ b2`.
function (n::Leaf)()::Tuple{Float32, Vector{Float32}}
    affine = n.params.w2 * n.x .+ n.params.b2
    y = first(affine)
    return (y, n.x)
end

# Interior node: a parameter set and two children.
struct Branch <: Node
    params::NamedTuple   # `w1`, `b1`, `w2` and `b2` are read when the branch is called
    left::Node
    right::Node
end

# Calling a branch evaluates the left child, then the right child, combines
# their hidden vectors into `h`, and returns `(y_left + y_right + y_own, h)`.
function (n::Branch)()::Tuple{Float32, Vector{Float32}}
    y_left, h_left = n.left()
    y_right, h_right = n.right()

    h = n.params.w1 * vcat(h_left, h_right) .+ n.params.b1
    y_own = first(n.params.w2 * h .+ n.params.b2)

    total = y_left + y_right + y_own
    return (total, h)
end

# Random tree generator: a `Leaf` when `rand() < p`, otherwise a `Branch` with two
# recursively generated children. Parameters come from `ps[rand(1:N)]`; leaves
# carry a random `Float32` vector of length `K`.
function genTree(N::Int, K::Int, ps::Vector, p::Float64)::Node
    return rand() < p ?
        Leaf(ps[rand(1:N)], rand(Float32, K)) :
        Branch(ps[rand(1:N)], genTree(N, K, ps, p), genTree(N, K, ps, p))
end

# Scalar output of `t1` minus scalar output of `t2`.
function loss(t1::Node, t2::Node)
    out1 = first(t1())
    out2 = first(t2())
    return out1 - out2
end

# Shadow copy of a leaf: parameters looked up in `mm`, input replaced by zeros.
function mirror(n::Leaf, mm::Dict)
    shadow = mm[n.params]
    return Leaf(shadow, zero(n.x))
end

# Shadow copy of a branch: parameters looked up in `mm`, children mirrored recursively.
mirror(n::Branch, mm::Dict) =
    Branch(mm[n.params], mirror(n.left, mm), mirror(n.right, mm))

# Build two random trees over a shared pool of parameter sets, mirror them into
# zero-initialised shadow trees, and differentiate `loss` in reverse mode.
function main()
    N = 10  # number of parameter sets in the pool
    K = 5   # length of the leaf / hidden vectors
    ps = map(1:N) do x
        (
            w1 = randn(Float32, (K,2*K)),
            b1 = randn(Float32, K),
            w2 = randn(Float32, (1,K)),
            b2 = randn(Float32),
        )
    end

    # Shadow storage, one entry per parameter set. `b2` is a Float32 scalar so
    # that each shadow NamedTuple has exactly the type of its primal counterpart
    # (`zeros(Float32)` would be a zero-dimensional array instead).
    grads = map(1:N) do x
        (
            w1 = zeros(Float32, (K,2*K)),
            b1 = zeros(Float32, K),
            w2 = zeros(Float32, (1,K)),
            b2 = 0.0f0,
        )
    end

    # primal parameter set => its shadow entry, consumed by `mirror`
    mm = Dict(p=>g for (p, g) in zip(ps, grads))

    # create a tree
    t1 = genTree(N, K, ps, 0.5)
    t2 = genTree(N, K, ps, 0.5)

    t1Grads = mirror(t1, mm)
    t2Grads = mirror(t2, mm)

    @show ll = loss(t1, t2)
    @show loss(t1Grads, t2Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads), Duplicated(t2, t2Grads))

    @show ll = loss(t1, t2)
end

main()
freddycct commented 1 year ago

Thanks for your attention!

Note that there is a minor bug (originally from the MWE I posted).

Use this instead:

# Corrected shadow storage: `b2` is the Float32 literal `0.0f0`.
grads = map(1:N) do x
        (
            w1 = zeros(Float32, (K,2*K)),
            b1 = zeros(Float32, K),
            w2 = zeros(Float32, (1,K)),
            b2 = 0.0f0
        )
    end
wsmoses commented 1 year ago

Reducing:

using Enzyme

# Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
Enzyme.API.strictAliasing!(false)
# Enzyme.API.looseTypeAnalysis!(true)

# Supertype shared by `Leaf` and `Branch`.
abstract type Node end

# Terminal node: a parameter set and an input vector.
struct Leaf <: Node
    params::NamedTuple   # only `params.b2` is read by the call method
    x::Vector{Float32}
end

# Calling a leaf returns `(first(params.b2), x)`.
function (n::Leaf)()::Tuple{Float32, Vector{Float32}}
    return first(n.params.b2), n.x
end

# Interior node with a single child.
struct Branch <: Node
    right::Node
end

# Calling a branch evaluates its child, discards the child's result, and returns a
# constant pair. The `Float64[]` literal is converted to `Vector{Float32}` by the
# declared return type.
function (n::Branch)()::Tuple{Float32, Vector{Float32}}
    _child_y, _child_h = n.right()
    return 0.0f0, Float64[]
end

# Scalar output (first tuple element) of calling the tree `t1`.
loss(t1::Node) = first(t1())

# Build a two-level tree around a single leaf, build a shadow tree of the same
# shape, and differentiate `loss` in reverse mode.
function main()
    ps = (b2 = 1.0f0,)
    grads = (b2 = 0.0f0,)

    # create a tree
    t1 = Branch(Branch(Leaf(ps, rand(Float32, 0))))
    @show t1

    # Shadow tree: same shape, but built from the zeroed `grads` rather than from
    # the primal parameters.
    t1Grads = Branch(Branch(Leaf(grads, zeros(Float32, 0))))

    @show ll = loss(t1)
    @show loss(t1Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads))
end

main()
wsmoses commented 1 year ago
using Enzyme

Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
Enzyme.API.strictAliasing!(false)
# Enzyme.API.looseTypeAnalysis!(true)

# Leaf node holding only a parameter set.
struct Leaf
    params::NamedTuple   # `LeafF` reads `params.b2`
end

# Return `first(params.b2)` as a Float32. The value is routed through a
# `convert(Tuple{Float32}, ...)` call before being unpacked again.
function LeafF(n::Leaf)::Float32
    boxed = convert(Tuple{Float32}, (first(n.params.b2),))
    return boxed[1]
end

# Interior node whose single child is either another `Branch` or a `Leaf`.
struct Branch
    right::Union{Branch, Leaf}
end

# Evaluate the child (`LeafF` for a leaf, recursion for a branch) and return the
# square of its result.
@noinline function BranchF(n::Branch)::Float32
    child = n.right
    res = child isa Leaf ? LeafF(child) : BranchF(child::Branch)
    return res * res
end

# Thin wrapper: the loss is `BranchF(t1)`, converted to Float32.
loss(t1)::Float32 = BranchF(t1)

# Build a one-branch tree and its shadow, then compile the split-mode thunks for
# `loss` and run only the forward half.
function main()
    ps = (b2 = 1.0f0,)
    grads = (b2 = 0.0f0,)

    # create a tree
    t1 = Branch(Leaf(ps))
    @show t1

    # Shadow tree built from the zeroed `grads` rather than from the primal parameters.
    t1Grads = Branch(Leaf(grads))

    @show ll = loss(t1)
    @show loss(t1Grads)

    forward, pullback = Enzyme.Compiler.thunk(loss, nothing, Active, Tuple{Duplicated{Branch}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1))
    @show forward(Duplicated(t1, t1Grads))
    # @test forward(Active(2.0)) == (nothing,)
    # @test pullback(Active(2.0), 1.0, nothing) == (1.0,)

    # autodiff(Reverse, loss, Duplicated(t1, t1Grads))
end

main()
wsmoses commented 1 year ago
using Enzyme

Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
Enzyme.API.strictAliasing!(false)
# Enzyme.API.looseTypeAnalysis!(true)

# Leaf node holding only a parameter set.
struct Leaf
    params::NamedTuple   # `LeafF` reads `params.b2`
end

# Return `first(params.b2)` as a Float32, passing it through a one-element
# `Tuple{Float32}` conversion on the way.
function LeafF(n::Leaf)::Float32
    raw = first(n.params.b2)
    wrapped = convert(Tuple{Float32}, (raw,))
    return first(wrapped)
end

# Build a single leaf and its shadow, compile the split-mode thunks for `LeafF`,
# run the forward half, and print what it returned.
function main()
    ps = (b2 = 1.0f0,)
    grads = (b2 = 0.0f0,)

    # create a tree
    t1 = Leaf(ps)
    @show t1

    # Shadow leaf built from the zeroed `grads` rather than from the primal parameters.
    t1Grads = Leaf(grads)

    @show ll = LeafF(t1)
    @show LeafF(t1Grads)

    forward, pullback = Enzyme.Compiler.thunk(LeafF, nothing, Active, Tuple{Duplicated{Leaf}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1))
    res = forward(Duplicated(t1, t1Grads))
    @show length(res)
    flush(stdout)   # flush between prints so earlier output is not lost if the next line crashes
    @show res[1]
    flush(stdout)
    # @test forward(Active(2.0)) == (nothing,)
    # @test pullback(Active(2.0), 1.0, nothing) == (1.0,)

    # autodiff(Reverse, loss, Duplicated(t1, t1Grads))
end

main()
wsmoses commented 1 year ago
using Enzyme

Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
# Enzyme.API.strictAliasing!(false)

# Leaf node: a parameter set and an input vector.
struct Leaf
    params::NamedTuple   # `LeafF` reads `params.w2`
    x::Vector{Float32}
end

# Return the first entry of `w2 * x`, wrapped in a one-element tuple.
function LeafF(n::Leaf)::Tuple{Float32}
    return (first(n.params.w2 * n.x),)
end

# The loss is the single element of the tuple returned by `LeafF`.
loss(t1) = first(LeafF(t1))

# Build one random leaf with a matching zero-initialised shadow and differentiate
# `loss` in reverse mode.
function main()
    N = 10  # number of parameter sets in the pool
    K = 5   # length of the leaf vector
    ps = map(1:N) do x
        (
            w2 = randn(Float32, (1,K)),
        )
    end

    grads = map(1:N) do x
        (
            w2 = zeros(Float32, (1,K)),
        )
    end

    # Pick one parameter set and pair it with the shadow entry at the same index,
    # so the shadow leaf never shares storage with the primal parameters.
    i = rand(1:N)
    t1 = Leaf(ps[i], rand(Float32, K))
    t1Grads = Leaf(grads[i], zeros(Float32, K))

    @show t1
    @show ll = loss(t1)
    @show loss(t1Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads))

    # `loss` takes a single tree here; there is no `t2` in this script.
    @show ll = loss(t1)
end

signal (11): Segmentation fault
in expression starting at /home/wmoses/git/Enzyme.jl/orinerr2.jl:51
gc_mark_loop at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:2582
_jl_gc_collect at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:3098
ijl_gc_collect at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:3327
maybe_collect at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:903
jl_gc_pool_alloc_inner at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:1247
jl_gc_pool_alloc_noinline at /home/wmoses/git/Enzyme.jl/julia8/src/gc.c:1306
jl_gc_alloc_ at /home/wmoses/git/Enzyme.jl/julia8/src/julia_internal.h:369
jl_new_uninitialized_datatype at /home/wmoses/git/Enzyme.jl/julia8/src/datatype.c:96
inst_datatype_inner at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1506
inst_type_w_ at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1820
ijl_instantiate_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1065
rename_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:497
unalias_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:752
subtype_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:763
subtype at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:1257
subtype_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:805
subtype at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:1254
exists_subtype at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:1391
forall_exists_subtype at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:1419
ijl_types_equal at /home/wmoses/git/Enzyme.jl/julia8/src/subtype.c:1973
typekey_eq at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:588
lookup_type_idx_linear at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:707
lookup_type at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:740
inst_datatype_inner at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1434
inst_type_w_ at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1820
ijl_instantiate_unionall at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:1065
ijl_apply_type at /home/wmoses/git/Enzyme.jl/julia8/src/jltypes.c:997
jl_f_apply_type at /home/wmoses/git/Enzyme.jl/julia8/src/builtins.c:1250
widenconst at ./compiler/typelattice.jl:335
jfptr_widenconst_13468 at /home/wmoses/git/Enzyme.jl/julia8/usr/lib/julia/sys-debug.so (unknown line)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
#261 at ./compiler/typeutils.jl:44
anymap at ./compiler/utilities.jl:43
argtypes_to_type at ./compiler/typeutils.jl:44 [inlined]
handle_const_call! at ./compiler/ssair/inlining.jl:1275
assemble_inline_todo! at ./compiler/ssair/inlining.jl:1410
ssa_inlining_pass! at ./compiler/ssair/inlining.jl:82
jfptr_ssa_inlining_passNOT._12623 at /home/wmoses/git/Enzyme.jl/julia8/usr/lib/julia/sys-debug.so (unknown line)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
run_passes at ./compiler/optimize.jl:539
optimize at ./compiler/optimize.jl:504 [inlined]
_typeinf at ./compiler/typeinfer.jl:257
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:967
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_15913 at /home/wmoses/git/Enzyme.jl/julia8/usr/lib/julia/sys-debug.so (unknown line)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
jl_apply at /home/wmoses/git/Enzyme.jl/julia8/src/julia.h:1839
jl_type_infer at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:319
jl_generate_fptr_impl at /home/wmoses/git/Enzyme.jl/julia8/src/jitlayers.cpp:319
jl_compile_method_internal at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2081
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2359
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
from_tape_type at /home/wmoses/git/Enzyme.jl/src/compiler.jl:3640
unknown function (ip: 0x7fb645fdc1b7)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
from_tape_type at /home/wmoses/git/Enzyme.jl/src/compiler.jl:3640
unknown function (ip: 0x7fb645fdc6d7)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2367
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
from_tape_type at /home/wmoses/git/Enzyme.jl/src/compiler.jl:3640
unknown function (ip: 0x7fb645fdb407)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2367
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
#159 at /home/wmoses/git/Enzyme.jl/src/compiler.jl:6345
#Builder#62 at /home/wmoses/.julia/packages/LLVM/WjSQG/src/irbuilder.jl:21
unknown function (ip: 0x7fb645f698e9)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
Builder at /home/wmoses/.julia/packages/LLVM/WjSQG/src/irbuilder.jl:18
#s841#158 at /home/wmoses/git/Enzyme.jl/src/compiler.jl:6323 [inlined]
#s841#158 at ./none:0
jl_fptr_args at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2128
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
GeneratedFunctionStub at ./boot.jl:582
jl_fptr_args at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2128
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
jl_apply at /home/wmoses/git/Enzyme.jl/julia8/src/julia.h:1839
jl_call_staged at /home/wmoses/git/Enzyme.jl/julia8/src/method.c:520
ijl_code_for_staged at /home/wmoses/git/Enzyme.jl/julia8/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:284
typeinf_ext at ./compiler/typeinfer.jl:965
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_15913 at /home/wmoses/git/Enzyme.jl/julia8/usr/lib/julia/sys-debug.so (unknown line)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2348
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
jl_apply at /home/wmoses/git/Enzyme.jl/julia8/src/julia.h:1839
jl_type_infer at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:319
jl_generate_fptr_impl at /home/wmoses/git/Enzyme.jl/julia8/src/jitlayers.cpp:319
jl_compile_method_internal at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2081
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2359
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
jl_apply at /home/wmoses/git/Enzyme.jl/julia8/src/julia.h:1839
do_apply at /home/wmoses/git/Enzyme.jl/julia8/src/builtins.c:730
jl_f__apply_iterate at /home/wmoses/git/Enzyme.jl/julia8/src/builtins.c:738
AdjointThunk at /home/wmoses/git/Enzyme.jl/src/compiler.jl:6098
jl_fptr_args at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2128
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2367
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
jl_apply at /home/wmoses/git/Enzyme.jl/julia8/src/julia.h:1839
do_apply at /home/wmoses/git/Enzyme.jl/julia8/src/builtins.c:730
jl_f__apply_iterate at /home/wmoses/git/Enzyme.jl/julia8/src/builtins.c:738
common_interface_rev at /home/wmoses/git/Enzyme.jl/src/compiler.jl:723
unknown function (ip: 0x7fb645fda59b)
_jl_invoke at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2367
ijl_apply_generic at /home/wmoses/git/Enzyme.jl/julia8/src/gf.c:2549
unknown function (ip: 0x7fb8775af6f5)
unknown function (ip: 0x7fb8775ac67b)
unknown function (ip: 0x7fb6429c45ff)
Allocations: 84930699 (Pool: 84870249; Big: 60450); GC: 94
Segmentation fault (core dumped)
wsmoses commented 1 year ago
using Enzyme
using LinearAlgebra

Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
# Enzyme.API.strictAliasing!(false)

# Leaf node: a parameter set plus an input vector.
struct Leaf
    params::NamedTuple  # field type is the unparameterised `NamedTuple`; `LeafF` reads `params.w2`
    x::Vector{Float32}
end

# Multiply `w2` by `nx` into a freshly allocated one-element Float32 vector and
# return that vector.
function gv(w2, nx)
    out = zeros(Float32, 1)
    # (alternative kept for reference) LinearAlgebra.mul!(out, w2, nx)
    return LinearAlgebra.gemv!(out, 'N', w2, nx, true, false)
end

# Return the first entry of `gv(params.w2, x)`, wrapped in a one-element tuple.
function LeafF(n::Leaf)::Tuple{Float32}
    return (first(gv(n.params.w2, n.x)),)
end

# The loss is the single element of the tuple returned by `LeafF`.
loss(t1) = first(LeafF(t1))

# Build one leaf with all-ones weights and a zero-initialised shadow, then
# differentiate `loss` in reverse mode.
function main()
    K = 5   # length of the leaf vector
    ps = (w2 = ones(Float32, (1,K)),)
    grads = (w2 = zeros(Float32, (1,K)),)

    # create a tree
    t1 = Leaf(ps, rand(Float32, K))
    t1Grads = Leaf(grads, zeros(Float32, K))

    @show t1
    @show ll = loss(t1)
    @show loss(t1Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads))

    # `loss` takes a single tree here; there is no `t2` in this script.
    @show ll = loss(t1)
end

main()
wsmoses commented 1 year ago
using Enzyme
using LinearAlgebra

Enzyme.API.printall!(true)
# Enzyme.API.printtype!(true)
# Enzyme.API.strictAliasing!(false)

# Leaf node: a parameter set plus an input vector.
struct Leaf
    params::NamedTuple  # field type is the unparameterised `NamedTuple`; `LeafF` reads `params.w2`
    x::Vector{Float32}
end

# Local stand-in for a matrix-vector product: promotes `α` and `β` to the element
# type `T` and forwards to `LinearAlgebra.BLAS.gemv!`, which writes into and
# returns `y`.
function mgemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
               α::Number=true, β::Number=false) where {T<:LinearAlgebra.BlasFloat}
    # The sizes are not used below; the call is kept as in the original body.
    _rows, _cols = LinearAlgebra.lapack_size(tA, A)
    scale_ax, scale_y, _ = promote(α, β, zero(T))
    return LinearAlgebra.BLAS.gemv!(tA, scale_ax, A, x, scale_y, y)
end

# Multiply `w2` by `nx` into a freshly allocated one-element Float32 vector
# (via `mgemv!`) and return that vector.
function gv(w2, nx)
    out = fill(0.0f0, 1)
    # (alternative kept for reference) LinearAlgebra.mul!(out, w2, nx)
    return mgemv!(out, 'N', w2, nx, true, false)
end

# Return the first entry of `gv(params.w2, x)`, wrapped in a one-element tuple.
function LeafF(n::Leaf)::Tuple{Float32}
    product = gv(n.params.w2, n.x)
    return (first(product),)
end

# The loss is the single element of the tuple returned by `LeafF`.
function loss(t1)
    (value,) = LeafF(t1)
    return value
end

# Build one leaf with all-ones weights and a zero-initialised shadow, then
# differentiate `loss` in reverse mode.
function main()
    K = 5   # length of the leaf vector
    ps = (w2 = ones(Float32, (1,K)),)
    grads = (w2 = zeros(Float32, (1,K)),)

    # create a tree
    t1 = Leaf(ps, rand(Float32, K))
    t1Grads = Leaf(grads, zeros(Float32, K))

    @show t1
    @show ll = loss(t1)
    @show loss(t1Grads)

    autodiff(Reverse, loss, Duplicated(t1, t1Grads))

    # `loss` takes a single tree here; there is no `t2` in this script.
    @show ll = loss(t1)
end

main()
freddycct commented 1 year ago

With type stability...

using Enzyme

# Supertype shared by `Leaf` and `Branch`.
abstract type Node end

# Terminal node: an index into the parameter vector plus an input vector.
struct Leaf <: Node
    i::Int               # position of this node's parameters in `θ`
    x::Vector{Float32}
end

# Evaluate the leaf against the parameter vector `θ`. Returns `(y, x)` where `y`
# is the first entry of `w2 * x .+ b2` taken from `θ[n.i]`.
function (n::Leaf)(θ::Vector)::Tuple{Float32, Vector{Float32}}
    p = θ[n.i]
    return first(p.w2 * n.x .+ p.b2), n.x
end

# Interior node: an index into the parameter vector plus two children.
struct Branch <: Node
    i::Int       # position of this node's parameters in `θ`
    left::Node
    right::Node
end

# Evaluate the subtree against `θ`: left child, then right child, then this
# node's own parameters `θ[n.i]`. Returns `(y_left + y_right + y_own, h)`.
function (n::Branch)(θ::Vector)::Tuple{Float32, Vector{Float32}}
    left_score, left_hidden = n.left(θ)
    right_score, right_hidden = n.right(θ)

    p = θ[n.i]
    hidden = p.w1 * vcat(left_hidden, right_hidden) .+ p.b1
    own_score = first(p.w2 * hidden .+ p.b2)

    return left_score + right_score + own_score, hidden
end

# Random tree generator: a `Leaf` when `rand() < p`, otherwise a `Branch` with two
# recursively generated children. Each node stores a parameter index drawn from
# `1:N`; leaves carry a random `Float32` vector of length `K`. `ps` is only passed
# through to the recursive calls.
function genTree(N::Int, K::Int, ps::Vector, p::Float64)::Node
    make_leaf = rand() < p
    idx = rand(1:N)   # drawn right after the coin flip in either case
    make_leaf && return Leaf(idx, rand(Float32, K))

    left = genTree(N, K, ps, p)
    right = genTree(N, K, ps, p)
    return Branch(idx, left, right)
end

# Squared difference between the scalar outputs of the two trees under `θ`.
function loss(t1::Node, t2::Node, θ::Vector)
    gap = first(t1(θ)) - first(t2(θ))
    return gap^2
end

# Build a pool of parameter sets with matching zeroed shadows, generate two
# random trees that index into the pool, and differentiate `loss` with respect
# to the parameters in reverse mode (the trees themselves are passed as `Const`).
function main()
    N, K = 10, 5

    θ = [
        (
            w1 = randn(Float32, (K, 2 * K)),
            b1 = randn(Float32, K),
            w2 = randn(Float32, (1, K)),
            b2 = randn(Float32),
        ) for _ in 1:N
    ]

    grads = [
        (
            w1 = zeros(Float32, (K, 2 * K)),
            b1 = zeros(Float32, K),
            w2 = zeros(Float32, (1, K)),
            b2 = 0.0f0,
        ) for _ in 1:N
    ]

    # create a tree
    t1 = genTree(N, K, θ, 0.5)
    t2 = genTree(N, K, θ, 0.5)

    @show loss(t1, t2, θ)

    autodiff(Reverse, loss, Const(t1), Const(t2), Duplicated(θ, grads))
end

main()

leads to this

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
loss(t1, t2, θ) = 38391.54f0
warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
ERROR: LoadError: Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_Leaf_3920({ float, {} addrspace(10)* }* noalias nocapture nonnull writeonly sret({ float, {} addrspace(10)* }) align 8 dereferenceable(16) %0, [1 x {} addrspace(10)*]* noalias nocapture writeonly %1, { i64, {} addrspace(10)* } addrspace(11)* nocapture nofree nonnull readonly align 8 dereferenceable(16) %2, {} addrspace(10)* nonnull align 16 dereferenceable(40) %3) unnamed_addr #32 !dbg !939 {
top:
  %4 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !940
  %5 = bitcast i8* %4 to [1 x [1 x i64]]*, !enzyme_caststack !12
  %6 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !940
  %7 = bitcast i8* %6 to [1 x [1 x i64]]*, !enzyme_caststack !12
  %8 = call {}*** @julia.get_pgcstack() #36
  %9 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %2, i64 0, i32 0, !dbg !941
  %10 = load i64, i64 addrspace(11)* %9, align 8, !dbg !943, !tbaa !54, !invariant.load !12
  %11 = add i64 %10, -1, !dbg !943
  %12 = bitcast {} addrspace(10)* %3 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !943
  %13 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %12 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !943
  %14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %13, i64 0, i32 1, !dbg !943
  %15 = load i64, i64 addrspace(11)* %14, align 8, !dbg !943, !tbaa !62, !range !58
  %16 = icmp ult i64 %11, %15, !dbg !943
  br i1 %16, label %idxend, label %oob, !dbg !943

L28:                                              ; preds = %pass
  %17 = addrspacecast {} addrspace(10)* %73 to {} addrspace(11)*, !dbg !944
  %18 = addrspacecast {} addrspace(10)* %66 to {} addrspace(11)*, !dbg !944
  %.not10 = icmp eq {} addrspace(11)* %17, %18, !dbg !944
  br i1 %.not10, label %L58, label %L31, !dbg !944

L31:                                              ; preds = %L28
  %19 = load i8, i8* inttoptr (i64 4698127368 to i8*), align 8, !dbg !953, !tbaa !54, !invariant.load !12
  %20 = and i8 %19, 8, !dbg !955
  %.not15.not = icmp eq i8 %20, 0, !dbg !955
  br i1 %.not15.not, label %L41, label %L58, !dbg !955

L41:                                              ; preds = %L31
  %21 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %17) #37, !dbg !957
  %22 = bitcast {}* %21 to i8**, !dbg !957
  %23 = load i8*, i8** %22, align 8, !dbg !957, !tbaa !104, !nonnull !12
  %24 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %18) #37, !dbg !957
  %25 = bitcast {}* %24 to i8**, !dbg !957
  %26 = load i8*, i8** %25, align 8, !dbg !957, !tbaa !104, !nonnull !12
  %.not18 = icmp eq i8* %23, %26, !dbg !960
  br i1 %.not18, label %L53, label %L58, !dbg !956

L53:                                              ; preds = %L41
  %27 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %66) #36, !dbg !963
  br label %L58, !dbg !956

L58:                                              ; preds = %L53, %L41, %L31, %L28
  %value_phi1 = phi {} addrspace(10)* [ %66, %L28 ], [ %27, %L53 ], [ %66, %L31 ], [ %66, %L41 ]
  %.not11 = icmp eq i64 %71, 0, !dbg !965
  br i1 %.not11, label %oob4, label %L107.lr.ph, !dbg !966

L107.lr.ph:                                       ; preds = %L58
  %28 = bitcast {} addrspace(10)* %value_phi1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !968
  %29 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %28 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !968
  %30 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %29, i64 0, i32 1, !dbg !968
  %31 = load i64, i64 addrspace(11)* %30, align 8, !dbg !968, !tbaa !62, !range !58
  %.not13 = icmp eq i64 %31, 1, !dbg !972
  %32 = bitcast {} addrspace(10)* %value_phi1 to float addrspace(13)* addrspace(10)*
  %33 = addrspacecast float addrspace(13)* addrspace(10)* %32 to float addrspace(13)* addrspace(11)*
  %34 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %33, align 8, !tbaa !104, !nonnull !12
  %35 = bitcast {} addrspace(10)* %73 to float addrspace(13)* addrspace(10)*
  %36 = addrspacecast float addrspace(13)* addrspace(10)* %35 to float addrspace(13)* addrspace(11)*
  %37 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %36, align 8, !tbaa !104, !nonnull !12
  br i1 %.not13, label %L107.us.preheader, label %L107.preheader, !dbg !976

L107.preheader:                                   ; preds = %L107.lr.ph
  br label %L107, !dbg !976

L107.us.preheader:                                ; preds = %L107.lr.ph
  br label %L107.us, !dbg !976

L107.us:                                          ; preds = %L107.us.preheader, %L107.us
  %iv1 = phi i64 [ %iv.next2, %L107.us ], [ 0, %L107.us.preheader ]
  %iv.next2 = add nuw nsw i64 %iv1, 1, !dbg !977
  %38 = load float, float addrspace(13)* %34, align 4, !dbg !977, !tbaa !384
  %39 = fadd float %67, %38, !dbg !984
  %40 = getelementptr inbounds float, float addrspace(13)* %37, i64 %iv1, !dbg !987
  store float %39, float addrspace(13)* %40, align 4, !dbg !987, !tbaa !384
  %exitcond21.not = icmp eq i64 %iv.next2, %71, !dbg !988
  br i1 %exitcond21.not, label %L124.loopexit, label %L107.us, !dbg !976, !llvm.loop !989

L107:                                             ; preds = %L107.preheader, %L107
  %iv = phi i64 [ %iv.next, %L107 ], [ 0, %L107.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !977
  %41 = getelementptr inbounds float, float addrspace(13)* %34, i64 %iv, !dbg !977
  %42 = load float, float addrspace(13)* %41, align 4, !dbg !977, !tbaa !384
  %43 = fadd float %67, %42, !dbg !984
  %44 = getelementptr inbounds float, float addrspace(13)* %37, i64 %iv, !dbg !987
  store float %43, float addrspace(13)* %44, align 4, !dbg !987, !tbaa !384
  %exitcond.not = icmp eq i64 %iv.next, %71, !dbg !988
  br i1 %exitcond.not, label %L124.loopexit3, label %L107, !dbg !976, !llvm.loop !989

L113:                                             ; preds = %pass
  %45 = getelementptr inbounds [1 x [1 x i64]], [1 x [1 x i64]]* %7, i64 0, i64 0, i64 0, !dbg !990
  store i64 %77, i64* %45, align 8, !dbg !990, !tbaa !238
  %46 = addrspacecast [1 x [1 x i64]]* %7 to [1 x [1 x i64]] addrspace(11)*, !dbg !992
  %47 = addrspacecast [1 x [1 x i64]]* %5 to [1 x [1 x i64]] addrspace(11)*, !dbg !992
  %48 = call fastcc nonnull {} addrspace(10)* @julia_throwdm_3926([1 x [1 x i64]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %46, [1 x [1 x i64]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %47) #38, !dbg !992
  unreachable, !dbg !992

L124.loopexit:                                    ; preds = %L107.us
  br label %L124, !dbg !993

L124.loopexit3:                                   ; preds = %L107
  br label %L124, !dbg !993

L124:                                             ; preds = %L124.loopexit3, %L124.loopexit
  br i1 %.not11, label %oob4, label %idxend5, !dbg !993

oob:                                              ; preds = %top
  %49 = alloca i64, align 8, !dbg !943
  store i64 %10, i64* %49, align 8, !dbg !943
  %50 = addrspacecast {} addrspace(10)* %3 to {} addrspace(12)*, !dbg !943
  call void @ijl_bounds_error_ints({} addrspace(12)* %50, i64* noundef nonnull align 8 %49, i64 noundef 1) #39, !dbg !943
  unreachable, !dbg !943

idxend:                                           ; preds = %top
  %51 = bitcast {} addrspace(10)* %3 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !943
  %52 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %51 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !943
  %53 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %52, align 16, !dbg !943, !tbaa !104, !nonnull !12
  %54 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %53, i64 %11, !dbg !943
  %55 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %54, align 8, !dbg !943, !tbaa !384
  %56 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %55, 0, !dbg !943
  %.not = icmp eq {} addrspace(10)* %56, null, !dbg !943
  br i1 %.not, label %fail, label %pass, !dbg !943

fail:                                             ; preds = %idxend
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4700173888 to {}*) to {} addrspace(12)*)) #39, !dbg !943
  unreachable, !dbg !943

pass:                                             ; preds = %idxend
  %57 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %55, 2, !dbg !996
  %58 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %2, i64 0, i32 1, !dbg !996
  %59 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %58 unordered, align 8, !dbg !996, !tbaa !54, !invariant.load !12, !nonnull !12, !dereferenceable !167, !align !168
  %60 = bitcast {} addrspace(10)* %57 to {} addrspace(10)* addrspace(10)*, !dbg !997
  %61 = addrspacecast {} addrspace(10)* addrspace(10)* %60 to {} addrspace(10)* addrspace(11)*, !dbg !997
  %62 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %61, i64 3, !dbg !997
  %63 = bitcast {} addrspace(10)* addrspace(11)* %62 to i64 addrspace(11)*, !dbg !997
  %64 = load i64, i64 addrspace(11)* %63, align 8, !dbg !997, !tbaa !54, !range !58, !invariant.load !12
  %65 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4698127312 to {}*) to {} addrspace(10)*), i64 %64) #36, !dbg !999
  %66 = call fastcc nonnull {} addrspace(10)* @julia_gemv__3932({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %65, i32 noundef zeroext 1308622848, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %57, {} addrspace(10)* nonnull align 16 dereferenceable(40) %59, i8 noundef zeroext 1, i8 noundef zeroext 0) #31, !dbg !1003
  %67 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %55, 3, !dbg !996
  %68 = bitcast {} addrspace(10)* %66 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !1005
  %69 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %68 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !1005
  %70 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %69, i64 0, i32 1, !dbg !1005
  %71 = load i64, i64 addrspace(11)* %70, align 8, !dbg !1005, !tbaa !62, !range !58
  %72 = getelementptr inbounds [1 x [1 x i64]], [1 x [1 x i64]]* %5, i64 0, i64 0, i64 0, !dbg !1009
  store i64 %71, i64* %72, align 8, !dbg !1009, !tbaa !238
  %73 = call nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4698127312 to {}*) to {} addrspace(10)*), i64 %71) #36, !dbg !1011
  %74 = bitcast {} addrspace(10)* %73 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !1018
  %75 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %74 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !1018
  %76 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %75, i64 0, i32 1, !dbg !1018
  %77 = load i64, i64 addrspace(11)* %76, align 8, !dbg !1018, !tbaa !62, !range !58
  %.not9 = icmp eq i64 %77, %71, !dbg !1019
  br i1 %.not9, label %L28, label %L113, !dbg !992

oob4:                                             ; preds = %L124, %L58
  %78 = alloca i64, align 8, !dbg !993
  store i64 1, i64* %78, align 8, !dbg !993
  %79 = addrspacecast {} addrspace(10)* %73 to {} addrspace(12)*, !dbg !993
  call void @ijl_bounds_error_ints({} addrspace(12)* %79, i64* noundef nonnull align 8 %78, i64 noundef 1) #39, !dbg !993
  unreachable, !dbg !993

idxend5:                                          ; preds = %L124
  %80 = bitcast {} addrspace(10)* %73 to float addrspace(13)* addrspace(10)*, !dbg !993
  %81 = addrspacecast float addrspace(13)* addrspace(10)* %80 to float addrspace(13)* addrspace(11)*, !dbg !993
  %82 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %81, align 8, !dbg !993, !tbaa !104, !nonnull !12
  %83 = load float, float addrspace(13)* %82, align 4, !dbg !993, !tbaa !384
  %84 = insertvalue { float, {} addrspace(10)* } zeroinitializer, float %83, 0, !dbg !1023
  %85 = insertvalue { float, {} addrspace(10)* } %84, {} addrspace(10)* %59, 1, !dbg !1023
  %86 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !1023
  store {} addrspace(10)* %59, {} addrspace(10)** %86, align 8, !dbg !1023
  store { float, {} addrspace(10)* } %85, { float, {} addrspace(10)* }* %0, align 8, !dbg !1023
  ret void, !dbg !1023
}

Illegal firstPointer, num: 16 q: {[]:Pointer, [0]:Float@float, [8]:Pointer, [8,0]:Pointer, [8,0,-1]:Float@float, [8,8]:Integer, [8,9]:Integer, [8,10]:Integer, [8,11]:Integer, [8,12]:Integer, [8,13]:Integer, [8,14]:Integer, [8,15]:Integer, [8,16]:Integer, [8,17]:Integer, [8,18]:Integer, [8,19]:Integer, [8,20]:Integer, [8,21]:Integer, [8,22]:Integer, [8,23]:Integer, [8,24]:Integer, [8,25]:Integer, [8,26]:Integer, [8,27]:Integer, [8,28]:Integer, [8,29]:Integer, [8,30]:Integer, [8,31]:Integer, [8,32]:Integer, [8,33]:Integer, [8,34]:Integer, [8,35]:Integer, [8,36]:Integer, [8,37]:Integer, [8,38]:Integer, [8,39]:Integer, [8,40]:Integer}
 at { float, {} addrspace(10)* }* %0 from   store { float, {} addrspace(10)* } %85, { float, {} addrspace(10)* }* %0, align 8, !dbg !203

Stacktrace:
 [1] Leaf
   @ ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:13

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:3518
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/7ekWs/src/api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Leaf, Vector{NamedTuple{(:w1, :b1, :w2, :b2), Tuple{Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Float32}}}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Const{Branch}, Const{Leaf}, Duplicated{Vector{NamedTuple{(:w1, :b1, :w2, :b2), Tuple{Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Float32}}}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:4762
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Leaf, Vector{NamedTuple{(:w1, :b1, :w2, :b2), Tuple{Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Float32}}}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:5854
  [5] _thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6321 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Leaf, Vector{NamedTuple{(:w1, :b1, :w2, :b2), Tuple{Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Float32}}}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6315
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6359
  [8] #s836#163
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6419 [inlined]
  [9] var"#s836#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
 [11] thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6452 [inlined]
 [12] thunk (repeats 2 times)
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6445 [inlined]
 [13] autodiff(::EnzymeCore.ReverseMode, ::typeof(loss), ::Type{Active{Float32}}, ::Const{Branch}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:199
 [14] autodiff(::EnzymeCore.ReverseMode, ::typeof(loss), ::Const{Branch}, ::Const{Leaf}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:236
 [15] main()
    @ Main ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:74
 [16] top-level scope
    @ ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:77
freddycct commented 1 year ago

@wsmoses, the workaround didn't work, but thanks for suggesting it.

ERROR: LoadError: Enzyme compilation failed due to illegal type analysis. Current scope: ; Function Attrs: mustprogress willreturn

using Enzyme

"""
    NodeLayer

Parameters applied at one tree node: a first affine map (`w₁`, `b₁`) that
produces the node's hidden vector, and a second one (`w₂`, `b₂`) whose first
output element is the node's scalar output.
"""
struct NodeLayer
    w₁::Matrix{Float32}
    b₁::Vector{Float32}
    w₂::Matrix{Float32}
    b₂::Float32
end

"""
    NodeLayer(K::Int)

Build a `NodeLayer` with standard-normal random parameters: `w₁` is `K × 2K`,
`b₁` has length `K`, `w₂` is `1 × K`, and `b₂` is a scalar.
"""
function NodeLayer(K::Int)
    return NodeLayer(
        randn(Float32, (K, 2 * K)),
        randn(Float32, K),
        randn(Float32, (1, K)),
        randn(Float32),
    )
end

# Apply the layer to the concatenation of `x₁` and `x₂`.
# Writes the scalar output into `y[]` and returns the hidden vector.
function (nl::NodeLayer)(y::Ref{Float32}, x₁::AbstractArray, x₂::AbstractArray)::Vector{Float32}
    stacked = vcat(x₁, x₂)
    hidden = nl.w₁ * stacked .+ nl.b₁
    scores = nl.w₂ * hidden .+ nl.b₂
    y[] = first(scores)
    return hidden
end

# Common supertype of the tree nodes (`Leaf` and `Branch`).
abstract type Node end

"""
    Leaf <: Node

Terminal tree node. `i` is the index of the `NodeLayer` to use in the
parameter vector, and `x` is the input vector stored at the leaf.
"""
struct Leaf <: Node
    i::Int
    x::Vector{Float32}
end

# Evaluate the leaf: look up layer `θ[n.i]` and apply it with the stored
# vector `n.x` as both inputs. The scalar output is written to `y[]` by the
# layer; the layer's hidden vector is returned.
function (n::Leaf)(y::Ref{Float32}, θ::Vector{NodeLayer})::Vector{Float32}
    return θ[n.i](y, n.x, n.x)
end

"""
    Branch <: Node

Internal tree node. `i` is the index of the `NodeLayer` to use in the
parameter vector; `left` and `right` are the child nodes.
"""
struct Branch <: Node
    i::Int
    left::Node
    right::Node
end

# Evaluate the branch: evaluate the left child, then the right child, then
# apply layer `θ[n.i]` to the two child hidden vectors. `y[]` receives the sum
# of the three scalar outputs; the branch's own hidden vector is returned.
function (n::Branch)(y::Ref{Float32}, θ::Vector{NodeLayer})::Vector{Float32}
    out_left = Ref{Float32}(0.0f0)
    out_right = Ref{Float32}(0.0f0)
    out_self = Ref{Float32}(0.0f0)

    h_left = n.left(out_left, θ)
    h_right = n.right(out_right, θ)
    h_self = θ[n.i](out_self, h_left, h_right)

    y[] = out_left[] + out_right[] + out_self[]
    return h_self
end

"""
    genTree(N::Int, K::Int, d::Int)::Node

Generate a random tree. For `d == 1` return a `Leaf` with a random layer index
in `1:N` and a random `Float32` vector of length `K`. Otherwise return a
`Branch` with a random layer index in `1:N` whose two subtrees are generated
with depth arguments drawn independently from `1:d-1`.
"""
function genTree(N::Int, K::Int, d::Int)::Node
    d == 1 && return Leaf(rand(1:N), rand(Float32, K))

    # Draw in the same order as the constructor arguments: index, left, right.
    idx = rand(1:N)
    left = genTree(N, K, rand(1:d-1))
    right = genTree(N, K, rand(1:d-1))
    return Branch(idx, left, right)
end

"""
    loss(t₁::Node, t₂::Node, θ::Vector{NodeLayer})::Float32

Evaluate both trees with parameters `θ` and return `sqrt(δ^2)`, where `δ` is
the difference between their scalar outputs.
"""
function loss(t₁::Node, t₂::Node, θ::Vector{NodeLayer})::Float32
    out₁ = Ref{Float32}(0.0f0)
    out₂ = Ref{Float32}(0.0f0)
    # Only the scalar outputs are needed; the returned hidden vectors are dropped.
    t₁(out₁, θ)
    t₂(out₂, θ)
    δ = out₁[] - out₂[]
    return sqrt(δ^2)
end

"""
    main()

Driver for the reproducer: build `N` random `NodeLayer`s and a matching set of
zero-initialised ones, generate two random trees, print the loss, and call
Enzyme's reverse-mode `autodiff` on `loss` with the trees marked `Const` and
the parameters marked `Duplicated`. Returns whatever `autodiff` returns.
"""
function main()
    N = 20 # size of the parameters
    D = 5 # depth of the trees
    K = 32

    θ = map(_ -> NodeLayer(K), 1:N) # these parameters make up the model returned by genTree
    # Zero-filled counterpart of θ with the same shapes, passed below as the
    # second argument of `Duplicated`.
    grads = map(1:N) do _
        NodeLayer(zeros(Float32, (K, 2 * K)), zeros(Float32, K), zeros(Float32, (1, K)), 0.0f0)
    end

    println("start training")

    # create a tree
    t₁ = genTree(N, K, D)
    t₂ = genTree(N, K, D)

    @show loss(t₁, t₂, θ)

    return autodiff(Reverse, loss, Const(t₁), Const(t₂), Duplicated(θ, grads))
end

# Run the reproducer.
main()

Full output here ...

ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc nonnull {} addrspace(10)* @preprocess_julia_Branch_3463({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree nonnull readonly align 8 dereferenceable(24) %0, {} addrspace(10)* nonnull writeonly align 4 dereferenceable(4) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) unnamed_addr #34 !dbg !878 {
top:
  %3 = call {}*** @julia.get_pgcstack() #35
  %4 = bitcast {}*** %3 to {}**
  %5 = getelementptr inbounds {}*, {}** %4, i64 -12
  %6 = getelementptr inbounds {}*, {}** %5, i64 14
  %7 = bitcast {}** %6 to i8**
  %8 = load i8*, i8** %7, align 8
  %9 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %8, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5027154608 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %10 = bitcast {} addrspace(10)* %9 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %11 = bitcast {}*** %3 to {}**
  %12 = getelementptr inbounds {}*, {}** %11, i64 -12
  %13 = getelementptr inbounds {}*, {}** %12, i64 14
  %14 = bitcast {}** %13 to i8**
  %15 = load i8*, i8** %14, align 8
  %16 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %15, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5027154608 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %17 = bitcast {} addrspace(10)* %16 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %18 = bitcast {}*** %3 to {}**
  %19 = getelementptr inbounds {}*, {}** %18, i64 -12
  %20 = getelementptr inbounds {}*, {}** %19, i64 14
  %21 = bitcast {}** %20 to i8**
  %22 = load i8*, i8** %21, align 8
  %23 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %22, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5027154608 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %24 = bitcast {} addrspace(10)* %23 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %current_task113 = getelementptr inbounds {}**, {}*** %3, i64 -12, !dbg !880
  %current_task1 = bitcast {}*** %current_task113 to {}**, !dbg !880
  %25 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4489596528 to {}*) to {} addrspace(10)*)) #36, !dbg !880
  %26 = bitcast {} addrspace(10)* %25 to float addrspace(10)*, !dbg !880
  store float 0.000000e+00, float addrspace(10)* %26, align 4, !dbg !880, !tbaa !76
  %27 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4489596528 to {}*) to {} addrspace(10)*)) #36, !dbg !883
  %28 = bitcast {} addrspace(10)* %27 to float addrspace(10)*, !dbg !883
  store float 0.000000e+00, float addrspace(10)* %28, align 4, !dbg !883, !tbaa !76
  %29 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4489596528 to {}*) to {} addrspace(10)*)) #36, !dbg !886
  %30 = bitcast {} addrspace(10)* %29 to float addrspace(10)*, !dbg !886
  store float 0.000000e+00, float addrspace(10)* %30, align 4, !dbg !886, !tbaa !76
  %31 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 1, !dbg !889
  %32 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %31 unordered, align 8, !dbg !889, !tbaa !57, !invariant.load !12, !nonnull !12
  %33 = call {} addrspace(10)* @julia.typeof({} addrspace(10)* nonnull %32) #37, !dbg !890
  %.not = icmp eq {} addrspace(10)* %33, addrspacecast ({}* inttoptr (i64 5023942704 to {}*) to {} addrspace(10)*), !dbg !890
  br i1 %.not, label %L7, label %L10, !dbg !890

L7:                                               ; preds = %top
  %34 = bitcast {} addrspace(10)* %32 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)*, !dbg !890
  %35 = addrspacecast { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)* %34 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)*, !dbg !890
  %36 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3463({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %35, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !890
  br label %L21, !dbg !890

L10:                                              ; preds = %top
  %.not20 = icmp eq {} addrspace(10)* %33, addrspacecast ({}* inttoptr (i64 5023365680 to {}*) to {} addrspace(10)*), !dbg !890
  br i1 %.not20, label %L12, label %L19, !dbg !890

L12:                                              ; preds = %L10
  %37 = bitcast {} addrspace(10)* %32 to i64 addrspace(10)*, !dbg !891
  %38 = addrspacecast i64 addrspace(10)* %37 to i64 addrspace(11)*, !dbg !891
  %39 = load i64, i64 addrspace(11)* %38, align 8, !dbg !893, !tbaa !775
  %40 = add i64 %39, -1, !dbg !893
  %41 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !893
  %42 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %41 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !893
  %43 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %42, i64 0, i32 1, !dbg !893
  %44 = load i64, i64 addrspace(11)* %43, align 8, !dbg !893, !tbaa !63, !range !59
  %45 = icmp ult i64 %40, %44, !dbg !893
  br i1 %45, label %idxend10, label %oob9, !dbg !893

L19:                                              ; preds = %L10
  %46 = call cc37 nonnull {} addrspace(10)* bitcast ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic to {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* nonnull %32, {} addrspace(10)* nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull %2) #35, !dbg !890
  br label %L21, !dbg !890

L21:                                              ; preds = %pass12, %L19, %L7
  %value_phi = phi {} addrspace(10)* [ %36, %L7 ], [ %117, %pass12 ], [ %46, %L19 ]
  %47 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 2, !dbg !894
  %48 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %47 unordered, align 8, !dbg !894, !tbaa !57, !invariant.load !12, !nonnull !12
  %49 = call {} addrspace(10)* @julia.typeof({} addrspace(10)* nonnull %48) #37, !dbg !895
  %.not16 = icmp eq {} addrspace(10)* %49, addrspacecast ({}* inttoptr (i64 5023942704 to {}*) to {} addrspace(10)*), !dbg !895
  br i1 %.not16, label %L25, label %L28, !dbg !895

L25:                                              ; preds = %L21
  %50 = bitcast {} addrspace(10)* %48 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)*, !dbg !895
  %51 = addrspacecast { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)* %50 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)*, !dbg !895
  %52 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3463({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %51, {} addrspace(10)* nonnull align 4 dereferenceable(4) %27, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !895
  br label %L39, !dbg !895

L28:                                              ; preds = %L21
  %.not18 = icmp eq {} addrspace(10)* %49, addrspacecast ({}* inttoptr (i64 5023365680 to {}*) to {} addrspace(10)*), !dbg !895
  br i1 %.not18, label %L30, label %L37, !dbg !895

L30:                                              ; preds = %L28
  %53 = bitcast {} addrspace(10)* %48 to i64 addrspace(10)*, !dbg !896
  %54 = addrspacecast i64 addrspace(10)* %53 to i64 addrspace(11)*, !dbg !896
  %55 = load i64, i64 addrspace(11)* %54, align 8, !dbg !898, !tbaa !775
  %56 = add i64 %55, -1, !dbg !898
  %57 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !898
  %58 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %57 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !898
  %59 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %58, i64 0, i32 1, !dbg !898
  %60 = load i64, i64 addrspace(11)* %59, align 8, !dbg !898, !tbaa !63, !range !59
  %61 = icmp ult i64 %56, %60, !dbg !898
  br i1 %61, label %idxend6, label %oob5, !dbg !898

L37:                                              ; preds = %L28
  %62 = call cc37 nonnull {} addrspace(10)* bitcast ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic to {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* nonnull %48, {} addrspace(10)* nonnull %27, {} addrspace(10)* nonnull %2) #35, !dbg !895
  br label %L39, !dbg !895

L39:                                              ; preds = %pass8, %L37, %L25
  %value_phi4 = phi {} addrspace(10)* [ %52, %L25 ], [ %103, %pass8 ], [ %62, %L37 ]
  %63 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 0, !dbg !899
  %64 = load i64, i64 addrspace(11)* %63, align 8, !dbg !901, !tbaa !57, !invariant.load !12
  %65 = add i64 %64, -1, !dbg !901
  %66 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !901
  %67 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %66 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !901
  %68 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %67, i64 0, i32 1, !dbg !901
  %69 = load i64, i64 addrspace(11)* %68, align 8, !dbg !901, !tbaa !63, !range !59
  %70 = icmp ult i64 %65, %69, !dbg !901
  br i1 %70, label %idxend, label %oob, !dbg !901

oob:                                              ; preds = %L39
  %71 = alloca i64, align 8, !dbg !901
  store i64 %64, i64* %71, align 8, !dbg !901
  %72 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !901
  call void @ijl_bounds_error_ints({} addrspace(12)* %72, i64* noundef nonnull align 8 %71, i64 noundef 1) #38, !dbg !901
  unreachable, !dbg !901

idxend:                                           ; preds = %L39
  %73 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !901
  %74 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %73 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !901
  %75 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %74, align 16, !dbg !901, !tbaa !128, !nonnull !12
  %76 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %75, i64 %65, !dbg !901
  %77 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %76, align 8, !dbg !901, !tbaa !650
  %78 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 0, !dbg !901
  %.not17 = icmp eq {} addrspace(10)* %78, null, !dbg !901
  br i1 %.not17, label %fail, label %pass, !dbg !901

fail:                                             ; preds = %idxend
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4697798208 to {}*) to {} addrspace(12)*)) #38, !dbg !901
  unreachable, !dbg !901

pass:                                             ; preds = %idxend
  %.fca.0.gep31 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 0, !dbg !902
  store {} addrspace(10)* %78, {} addrspace(10)* addrspace(10)* %.fca.0.gep31, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %78), !dbg !902
  %.fca.1.extract32 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 1, !dbg !902
  %.fca.1.gep33 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 1, !dbg !902
  store {} addrspace(10)* %.fca.1.extract32, {} addrspace(10)* addrspace(10)* %.fca.1.gep33, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %.fca.1.extract32), !dbg !902
  %.fca.2.extract34 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 2, !dbg !902
  %.fca.2.gep35 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 2, !dbg !902
  store {} addrspace(10)* %.fca.2.extract34, {} addrspace(10)* addrspace(10)* %.fca.2.gep35, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %.fca.2.extract34), !dbg !902
  %.fca.3.extract36 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 3, !dbg !902
  %.fca.3.gep37 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 3, !dbg !902
  store float %.fca.3.extract36, float addrspace(10)* %.fca.3.gep37, align 8, !dbg !902
  %79 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !902
  %80 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3468({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %79, {} addrspace(10)* nonnull align 4 dereferenceable(4) %29, {} addrspace(10)* nonnull align 16 dereferenceable(40) %value_phi, {} addrspace(10)* nonnull align 16 dereferenceable(40) %value_phi4) #33, !dbg !902
  %81 = addrspacecast float addrspace(10)* %26 to float addrspace(11)*, !dbg !903
  %82 = load float, float addrspace(11)* %81, align 4, !dbg !903, !tbaa !76
  %83 = addrspacecast float addrspace(10)* %28 to float addrspace(11)*, !dbg !903
  %84 = load float, float addrspace(11)* %83, align 4, !dbg !903, !tbaa !76
  %85 = addrspacecast float addrspace(10)* %30 to float addrspace(11)*, !dbg !903
  %86 = load float, float addrspace(11)* %85, align 4, !dbg !903, !tbaa !76
  %87 = fadd float %82, %84, !dbg !906
  %88 = fadd float %87, %86, !dbg !906
  %89 = bitcast {} addrspace(10)* %1 to float addrspace(10)*, !dbg !908
  store float %88, float addrspace(10)* %89, align 4, !dbg !908, !tbaa !76
  ret {} addrspace(10)* %80, !dbg !910

oob5:                                             ; preds = %L30
  %90 = alloca i64, align 8, !dbg !898
  store i64 %55, i64* %90, align 8, !dbg !898
  %91 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !898
  call void @ijl_bounds_error_ints({} addrspace(12)* %91, i64* noundef nonnull align 8 %90, i64 noundef 1) #38, !dbg !898
  unreachable, !dbg !898

idxend6:                                          ; preds = %L30
  %92 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !898
  %93 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %92 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !898
  %94 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %93, align 16, !dbg !898, !tbaa !128, !nonnull !12
  %95 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %94, i64 %56, !dbg !898
  %96 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %95, align 8, !dbg !898, !tbaa !650
  %97 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 0, !dbg !898
  %.not19 = icmp eq {} addrspace(10)* %97, null, !dbg !898
  br i1 %.not19, label %fail7, label %pass8, !dbg !898

fail7:                                            ; preds = %idxend6
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4697798208 to {}*) to {} addrspace(12)*)) #38, !dbg !898
  unreachable, !dbg !898

pass8:                                            ; preds = %idxend6
  %98 = bitcast {} addrspace(10)* %48 to { i64, {} addrspace(10)* } addrspace(10)*, !dbg !911
  %99 = addrspacecast { i64, {} addrspace(10)* } addrspace(10)* %98 to { i64, {} addrspace(10)* } addrspace(11)*, !dbg !911
  %100 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %99, i64 0, i32 1, !dbg !911
  %101 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %100 unordered, align 8, !dbg !911, !tbaa !775, !nonnull !12, !dereferenceable !297, !align !298
  %.fca.0.gep23 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 0, !dbg !912
  store {} addrspace(10)* %97, {} addrspace(10)* addrspace(10)* %.fca.0.gep23, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %97), !dbg !912
  %.fca.1.extract24 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 1, !dbg !912
  %.fca.1.gep25 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 1, !dbg !912
  store {} addrspace(10)* %.fca.1.extract24, {} addrspace(10)* addrspace(10)* %.fca.1.gep25, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %.fca.1.extract24), !dbg !912
  %.fca.2.extract26 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 2, !dbg !912
  %.fca.2.gep27 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 2, !dbg !912
  store {} addrspace(10)* %.fca.2.extract26, {} addrspace(10)* addrspace(10)* %.fca.2.gep27, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %.fca.2.extract26), !dbg !912
  %.fca.3.extract28 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 3, !dbg !912
  %.fca.3.gep29 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 3, !dbg !912
  store float %.fca.3.extract28, float addrspace(10)* %.fca.3.gep29, align 8, !dbg !912
  %102 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !912
  %103 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3468({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %102, {} addrspace(10)* nonnull align 4 dereferenceable(4) %27, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %101, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %101) #33, !dbg !912
  br label %L39, !dbg !895

oob9:                                             ; preds = %L12
  %104 = alloca i64, align 8, !dbg !893
  store i64 %39, i64* %104, align 8, !dbg !893
  %105 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !893
  call void @ijl_bounds_error_ints({} addrspace(12)* %105, i64* noundef nonnull align 8 %104, i64 noundef 1) #38, !dbg !893
  unreachable, !dbg !893

idxend10:                                         ; preds = %L12
  %106 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !893
  %107 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %106 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !893
  %108 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %107, align 16, !dbg !893, !tbaa !128, !nonnull !12
  %109 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %108, i64 %40, !dbg !893
  %110 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %109, align 8, !dbg !893, !tbaa !650
  %111 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 0, !dbg !893
  %.not21 = icmp eq {} addrspace(10)* %111, null, !dbg !893
  br i1 %.not21, label %fail11, label %pass12, !dbg !893

fail11:                                           ; preds = %idxend10
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4697798208 to {}*) to {} addrspace(12)*)) #38, !dbg !893
  unreachable, !dbg !893

pass12:                                           ; preds = %idxend10
  %112 = bitcast {} addrspace(10)* %32 to { i64, {} addrspace(10)* } addrspace(10)*, !dbg !913
  %113 = addrspacecast { i64, {} addrspace(10)* } addrspace(10)* %112 to { i64, {} addrspace(10)* } addrspace(11)*, !dbg !913
  %114 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %113, i64 0, i32 1, !dbg !913
  %115 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %114 unordered, align 8, !dbg !913, !tbaa !775, !nonnull !12, !dereferenceable !297, !align !298
  %.fca.0.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 0, !dbg !914
  store {} addrspace(10)* %111, {} addrspace(10)* addrspace(10)* %.fca.0.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %111), !dbg !914
  %.fca.1.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 1, !dbg !914
  %.fca.1.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 1, !dbg !914
  store {} addrspace(10)* %.fca.1.extract, {} addrspace(10)* addrspace(10)* %.fca.1.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %.fca.1.extract), !dbg !914
  %.fca.2.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 2, !dbg !914
  %.fca.2.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 2, !dbg !914
  store {} addrspace(10)* %.fca.2.extract, {} addrspace(10)* addrspace(10)* %.fca.2.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %.fca.2.extract), !dbg !914
  %.fca.3.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 3, !dbg !914
  %.fca.3.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 3, !dbg !914
  store float %.fca.3.extract, float addrspace(10)* %.fca.3.gep, align 8, !dbg !914
  %116 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !914
  %117 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3468({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %116, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %115, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %115) #33, !dbg !914
  br label %L21, !dbg !890
}

 Type analysis state:

Illegal type analysis update from julia rule of method MethodInstance for (::Branch)(::Base.RefValue{Float32}, ::Vector{NodeLayer})
Found type Branch at index 1 of {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,16]:Pointer}
Prior type {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}
  %36 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3463({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %35, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !69

Caused by:
Stacktrace:
 [1] Branch
   @ ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:42

Stacktrace:
  [1] julia_type_rule(direction::Int32, ret::Ptr{Enzyme.API.EnzymeTypeTree}, args::Ptr{Ptr{Enzyme.API.EnzymeTypeTree}}, known_values::Ptr{Enzyme.API.IntList}, numArgs::UInt64, val::Ptr{LLVM.API.LLVMOpaqueValue})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:4510
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/7ekWs/src/api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Const{Branch}, Const{Branch}, Duplicated{Vector{NodeLayer}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:4762
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:5854
  [5] _thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6321 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6315
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6359
  [8] #s836#163
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6419 [inlined]
  [9] var"#s836#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
 [11] thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6452 [inlined]
 [12] thunk (repeats 2 times)
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6445 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:199 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:236 [inlined]
 [15] main()
    @ Main ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:89
 [16] top-level scope
    @ ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:92
wsmoses commented 1 year ago

@freddycct Try adding `Enzyme.API.strictTypeAnalysis!(false)`.

freddycct commented 1 year ago

I still get this error:

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
start training
loss(t₁, t₂, θ) = 1280.9806f0
warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

warning: Linking two modules of different target triples: 'bcloader' is 'arm64-apple-macosx11.0.0' whereas 'text' is 'arm64-apple-darwin21.6.0'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Fu1YT/src/utils.jl:35
ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc nonnull {} addrspace(10)* @preprocess_julia_Branch_3467({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree nonnull readonly align 8 dereferenceable(24) %0, {} addrspace(10)* nonnull writeonly align 4 dereferenceable(4) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) unnamed_addr #34 !dbg !878 {
top:
  %3 = call {}*** @julia.get_pgcstack() #35
  %4 = bitcast {}*** %3 to {}**
  %5 = getelementptr inbounds {}*, {}** %4, i64 -12
  %6 = getelementptr inbounds {}*, {}** %5, i64 14
  %7 = bitcast {}** %6 to i8**
  %8 = load i8*, i8** %7, align 8
  %9 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %8, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5268269792 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %10 = bitcast {} addrspace(10)* %9 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %11 = bitcast {}*** %3 to {}**
  %12 = getelementptr inbounds {}*, {}** %11, i64 -12
  %13 = getelementptr inbounds {}*, {}** %12, i64 14
  %14 = bitcast {}** %13 to i8**
  %15 = load i8*, i8** %14, align 8
  %16 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %15, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5268269792 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %17 = bitcast {} addrspace(10)* %16 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %18 = bitcast {}*** %3 to {}**
  %19 = getelementptr inbounds {}*, {}** %18, i64 -12
  %20 = getelementptr inbounds {}*, {}** %19, i64 14
  %21 = bitcast {}** %20 to i8**
  %22 = load i8*, i8** %21, align 8
  %23 = call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) {} addrspace(10)* @jl_gc_alloc_typed(i8* %22, i64 32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5268269792 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !879
  %24 = bitcast {} addrspace(10)* %23 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)*, !enzyme_caststack !12
  %current_task113 = getelementptr inbounds {}**, {}*** %3, i64 -12, !dbg !880
  %current_task1 = bitcast {}*** %current_task113 to {}**, !dbg !880
  %25 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5272990864 to {}*) to {} addrspace(10)*)) #36, !dbg !880
  %26 = bitcast {} addrspace(10)* %25 to float addrspace(10)*, !dbg !880
  store float 0.000000e+00, float addrspace(10)* %26, align 4, !dbg !880, !tbaa !76
  %27 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5272990864 to {}*) to {} addrspace(10)*)) #36, !dbg !883
  %28 = bitcast {} addrspace(10)* %27 to float addrspace(10)*, !dbg !883
  store float 0.000000e+00, float addrspace(10)* %28, align 4, !dbg !883, !tbaa !76
  %29 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 4, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5272990864 to {}*) to {} addrspace(10)*)) #36, !dbg !886
  %30 = bitcast {} addrspace(10)* %29 to float addrspace(10)*, !dbg !886
  store float 0.000000e+00, float addrspace(10)* %30, align 4, !dbg !886, !tbaa !76
  %31 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 1, !dbg !889
  %32 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %31 unordered, align 8, !dbg !889, !tbaa !57, !invariant.load !12, !nonnull !12
  %33 = call {} addrspace(10)* @julia.typeof({} addrspace(10)* nonnull %32) #37, !dbg !890
  %.not = icmp eq {} addrspace(10)* %33, addrspacecast ({}* inttoptr (i64 5241070352 to {}*) to {} addrspace(10)*), !dbg !890
  br i1 %.not, label %L7, label %L10, !dbg !890

L7:                                               ; preds = %top
  %34 = bitcast {} addrspace(10)* %32 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)*, !dbg !890
  %35 = addrspacecast { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)* %34 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)*, !dbg !890
  %36 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3467({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %35, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !890
  br label %L21, !dbg !890

L10:                                              ; preds = %top
  %.not20 = icmp eq {} addrspace(10)* %33, addrspacecast ({}* inttoptr (i64 5241063952 to {}*) to {} addrspace(10)*), !dbg !890
  br i1 %.not20, label %L12, label %L19, !dbg !890

L12:                                              ; preds = %L10
  %37 = bitcast {} addrspace(10)* %32 to i64 addrspace(10)*, !dbg !891
  %38 = addrspacecast i64 addrspace(10)* %37 to i64 addrspace(11)*, !dbg !891
  %39 = load i64, i64 addrspace(11)* %38, align 8, !dbg !893, !tbaa !775
  %40 = add i64 %39, -1, !dbg !893
  %41 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !893
  %42 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %41 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !893
  %43 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %42, i64 0, i32 1, !dbg !893
  %44 = load i64, i64 addrspace(11)* %43, align 8, !dbg !893, !tbaa !63, !range !59
  %45 = icmp ult i64 %40, %44, !dbg !893
  br i1 %45, label %idxend10, label %oob9, !dbg !893

L19:                                              ; preds = %L10
  %46 = call cc37 nonnull {} addrspace(10)* bitcast ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic to {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* nonnull %32, {} addrspace(10)* nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull %2) #35, !dbg !890
  br label %L21, !dbg !890

L21:                                              ; preds = %pass12, %L19, %L7
  %value_phi = phi {} addrspace(10)* [ %36, %L7 ], [ %117, %pass12 ], [ %46, %L19 ]
  %47 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 2, !dbg !894
  %48 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %47 unordered, align 8, !dbg !894, !tbaa !57, !invariant.load !12, !nonnull !12
  %49 = call {} addrspace(10)* @julia.typeof({} addrspace(10)* nonnull %48) #37, !dbg !895
  %.not16 = icmp eq {} addrspace(10)* %49, addrspacecast ({}* inttoptr (i64 5241070352 to {}*) to {} addrspace(10)*), !dbg !895
  br i1 %.not16, label %L25, label %L28, !dbg !895

L25:                                              ; preds = %L21
  %50 = bitcast {} addrspace(10)* %48 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)*, !dbg !895
  %51 = addrspacecast { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(10)* %50 to { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)*, !dbg !895
  %52 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3467({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %51, {} addrspace(10)* nonnull align 4 dereferenceable(4) %27, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !895
  br label %L39, !dbg !895

L28:                                              ; preds = %L21
  %.not18 = icmp eq {} addrspace(10)* %49, addrspacecast ({}* inttoptr (i64 5241063952 to {}*) to {} addrspace(10)*), !dbg !895
  br i1 %.not18, label %L30, label %L37, !dbg !895

L30:                                              ; preds = %L28
  %53 = bitcast {} addrspace(10)* %48 to i64 addrspace(10)*, !dbg !896
  %54 = addrspacecast i64 addrspace(10)* %53 to i64 addrspace(11)*, !dbg !896
  %55 = load i64, i64 addrspace(11)* %54, align 8, !dbg !898, !tbaa !775
  %56 = add i64 %55, -1, !dbg !898
  %57 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !898
  %58 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %57 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !898
  %59 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %58, i64 0, i32 1, !dbg !898
  %60 = load i64, i64 addrspace(11)* %59, align 8, !dbg !898, !tbaa !63, !range !59
  %61 = icmp ult i64 %56, %60, !dbg !898
  br i1 %61, label %idxend6, label %oob5, !dbg !898

L37:                                              ; preds = %L28
  %62 = call cc37 nonnull {} addrspace(10)* bitcast ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic to {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* nonnull %48, {} addrspace(10)* nonnull %27, {} addrspace(10)* nonnull %2) #35, !dbg !895
  br label %L39, !dbg !895

L39:                                              ; preds = %pass8, %L37, %L25
  %value_phi4 = phi {} addrspace(10)* [ %52, %L25 ], [ %103, %pass8 ], [ %62, %L37 ]
  %63 = getelementptr inbounds { i64, {} addrspace(10)*, {} addrspace(10)* }, { i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* %0, i64 0, i32 0, !dbg !899
  %64 = load i64, i64 addrspace(11)* %63, align 8, !dbg !901, !tbaa !57, !invariant.load !12
  %65 = add i64 %64, -1, !dbg !901
  %66 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !901
  %67 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %66 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !901
  %68 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %67, i64 0, i32 1, !dbg !901
  %69 = load i64, i64 addrspace(11)* %68, align 8, !dbg !901, !tbaa !63, !range !59
  %70 = icmp ult i64 %65, %69, !dbg !901
  br i1 %70, label %idxend, label %oob, !dbg !901

oob:                                              ; preds = %L39
  %71 = alloca i64, align 8, !dbg !901
  store i64 %64, i64* %71, align 8, !dbg !901
  %72 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !901
  call void @ijl_bounds_error_ints({} addrspace(12)* %72, i64* noundef nonnull align 8 %71, i64 noundef 1) #38, !dbg !901
  unreachable, !dbg !901

idxend:                                           ; preds = %L39
  %73 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !901
  %74 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %73 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !901
  %75 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %74, align 16, !dbg !901, !tbaa !128, !nonnull !12
  %76 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %75, i64 %65, !dbg !901
  %77 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %76, align 8, !dbg !901, !tbaa !650
  %78 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 0, !dbg !901
  %.not17 = icmp eq {} addrspace(10)* %78, null, !dbg !901
  br i1 %.not17, label %fail, label %pass, !dbg !901

fail:                                             ; preds = %idxend
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4764939840 to {}*) to {} addrspace(12)*)) #38, !dbg !901
  unreachable, !dbg !901

pass:                                             ; preds = %idxend
  %.fca.0.gep31 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 0, !dbg !902
  store {} addrspace(10)* %78, {} addrspace(10)* addrspace(10)* %.fca.0.gep31, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %78), !dbg !902
  %.fca.1.extract32 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 1, !dbg !902
  %.fca.1.gep33 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 1, !dbg !902
  store {} addrspace(10)* %.fca.1.extract32, {} addrspace(10)* addrspace(10)* %.fca.1.gep33, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %.fca.1.extract32), !dbg !902
  %.fca.2.extract34 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 2, !dbg !902
  %.fca.2.gep35 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 2, !dbg !902
  store {} addrspace(10)* %.fca.2.extract34, {} addrspace(10)* addrspace(10)* %.fca.2.gep35, align 8, !dbg !902
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %16, {} addrspace(10)* %.fca.2.extract34), !dbg !902
  %.fca.3.extract36 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %77, 3, !dbg !902
  %.fca.3.gep37 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17, i64 0, i32 3, !dbg !902
  store float %.fca.3.extract36, float addrspace(10)* %.fca.3.gep37, align 8, !dbg !902
  %79 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %17 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !902
  %80 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3472({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %79, {} addrspace(10)* nonnull align 4 dereferenceable(4) %29, {} addrspace(10)* nonnull align 16 dereferenceable(40) %value_phi, {} addrspace(10)* nonnull align 16 dereferenceable(40) %value_phi4) #33, !dbg !902
  %81 = addrspacecast float addrspace(10)* %26 to float addrspace(11)*, !dbg !903
  %82 = load float, float addrspace(11)* %81, align 4, !dbg !903, !tbaa !76
  %83 = addrspacecast float addrspace(10)* %28 to float addrspace(11)*, !dbg !903
  %84 = load float, float addrspace(11)* %83, align 4, !dbg !903, !tbaa !76
  %85 = addrspacecast float addrspace(10)* %30 to float addrspace(11)*, !dbg !903
  %86 = load float, float addrspace(11)* %85, align 4, !dbg !903, !tbaa !76
  %87 = fadd float %82, %84, !dbg !906
  %88 = fadd float %87, %86, !dbg !906
  %89 = bitcast {} addrspace(10)* %1 to float addrspace(10)*, !dbg !908
  store float %88, float addrspace(10)* %89, align 4, !dbg !908, !tbaa !76
  ret {} addrspace(10)* %80, !dbg !910

oob5:                                             ; preds = %L30
  %90 = alloca i64, align 8, !dbg !898
  store i64 %55, i64* %90, align 8, !dbg !898
  %91 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !898
  call void @ijl_bounds_error_ints({} addrspace(12)* %91, i64* noundef nonnull align 8 %90, i64 noundef 1) #38, !dbg !898
  unreachable, !dbg !898

idxend6:                                          ; preds = %L30
  %92 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !898
  %93 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %92 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !898
  %94 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %93, align 16, !dbg !898, !tbaa !128, !nonnull !12
  %95 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %94, i64 %56, !dbg !898
  %96 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %95, align 8, !dbg !898, !tbaa !650
  %97 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 0, !dbg !898
  %.not19 = icmp eq {} addrspace(10)* %97, null, !dbg !898
  br i1 %.not19, label %fail7, label %pass8, !dbg !898

fail7:                                            ; preds = %idxend6
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4764939840 to {}*) to {} addrspace(12)*)) #38, !dbg !898
  unreachable, !dbg !898

pass8:                                            ; preds = %idxend6
  %98 = bitcast {} addrspace(10)* %48 to { i64, {} addrspace(10)* } addrspace(10)*, !dbg !911
  %99 = addrspacecast { i64, {} addrspace(10)* } addrspace(10)* %98 to { i64, {} addrspace(10)* } addrspace(11)*, !dbg !911
  %100 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %99, i64 0, i32 1, !dbg !911
  %101 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %100 unordered, align 8, !dbg !911, !tbaa !775, !nonnull !12, !dereferenceable !297, !align !298
  %.fca.0.gep23 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 0, !dbg !912
  store {} addrspace(10)* %97, {} addrspace(10)* addrspace(10)* %.fca.0.gep23, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %97), !dbg !912
  %.fca.1.extract24 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 1, !dbg !912
  %.fca.1.gep25 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 1, !dbg !912
  store {} addrspace(10)* %.fca.1.extract24, {} addrspace(10)* addrspace(10)* %.fca.1.gep25, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %.fca.1.extract24), !dbg !912
  %.fca.2.extract26 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 2, !dbg !912
  %.fca.2.gep27 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 2, !dbg !912
  store {} addrspace(10)* %.fca.2.extract26, {} addrspace(10)* addrspace(10)* %.fca.2.gep27, align 8, !dbg !912
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* %.fca.2.extract26), !dbg !912
  %.fca.3.extract28 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %96, 3, !dbg !912
  %.fca.3.gep29 = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10, i64 0, i32 3, !dbg !912
  store float %.fca.3.extract28, float addrspace(10)* %.fca.3.gep29, align 8, !dbg !912
  %102 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %10 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !912
  %103 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3472({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %102, {} addrspace(10)* nonnull align 4 dereferenceable(4) %27, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %101, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %101) #33, !dbg !912
  br label %L39, !dbg !895

oob9:                                             ; preds = %L12
  %104 = alloca i64, align 8, !dbg !893
  store i64 %39, i64* %104, align 8, !dbg !893
  %105 = addrspacecast {} addrspace(10)* %2 to {} addrspace(12)*, !dbg !893
  call void @ijl_bounds_error_ints({} addrspace(12)* %105, i64* noundef nonnull align 8 %104, i64 noundef 1) #38, !dbg !893
  unreachable, !dbg !893

idxend10:                                         ; preds = %L12
  %106 = bitcast {} addrspace(10)* %2 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)*, !dbg !893
  %107 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(10)* %106 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)*, !dbg !893
  %108 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* addrspace(11)* %107, align 16, !dbg !893, !tbaa !128, !nonnull !12
  %109 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %108, i64 %40, !dbg !893
  %110 = load { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(13)* %109, align 8, !dbg !893, !tbaa !650
  %111 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 0, !dbg !893
  %.not21 = icmp eq {} addrspace(10)* %111, null, !dbg !893
  br i1 %.not21, label %fail11, label %pass12, !dbg !893

fail11:                                           ; preds = %idxend10
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 4764939840 to {}*) to {} addrspace(12)*)) #38, !dbg !893
  unreachable, !dbg !893

pass12:                                           ; preds = %idxend10
  %112 = bitcast {} addrspace(10)* %32 to { i64, {} addrspace(10)* } addrspace(10)*, !dbg !913
  %113 = addrspacecast { i64, {} addrspace(10)* } addrspace(10)* %112 to { i64, {} addrspace(10)* } addrspace(11)*, !dbg !913
  %114 = getelementptr inbounds { i64, {} addrspace(10)* }, { i64, {} addrspace(10)* } addrspace(11)* %113, i64 0, i32 1, !dbg !913
  %115 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %114 unordered, align 8, !dbg !913, !tbaa !775, !nonnull !12, !dereferenceable !297, !align !298
  %.fca.0.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 0, !dbg !914
  store {} addrspace(10)* %111, {} addrspace(10)* addrspace(10)* %.fca.0.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %111), !dbg !914
  %.fca.1.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 1, !dbg !914
  %.fca.1.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 1, !dbg !914
  store {} addrspace(10)* %.fca.1.extract, {} addrspace(10)* addrspace(10)* %.fca.1.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %.fca.1.extract), !dbg !914
  %.fca.2.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 2, !dbg !914
  %.fca.2.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 2, !dbg !914
  store {} addrspace(10)* %.fca.2.extract, {} addrspace(10)* addrspace(10)* %.fca.2.gep, align 8, !dbg !914
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %23, {} addrspace(10)* %.fca.2.extract), !dbg !914
  %.fca.3.extract = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } %110, 3, !dbg !914
  %.fca.3.gep = getelementptr { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float }, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24, i64 0, i32 3, !dbg !914
  store float %.fca.3.extract, float addrspace(10)* %.fca.3.gep, align 8, !dbg !914
  %116 = addrspacecast { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(10)* %24 to { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)*, !dbg !914
  %117 = call fastcc nonnull {} addrspace(10)* @julia_NodeLayer_3472({ {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, float } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %116, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %115, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %115) #33, !dbg !914
  br label %L21, !dbg !890
}

 Type analysis state:

Illegal type analysis update from julia rule of method MethodInstance for (::Branch)(::Base.RefValue{Float32}, ::Vector{NodeLayer})
Found type Branch at index 1 of {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,16]:Pointer}
Prior type {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}
  %36 = call fastcc nonnull {} addrspace(10)* @julia_Branch_3467({ i64, {} addrspace(10)*, {} addrspace(10)* } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %35, {} addrspace(10)* noundef nonnull align 4 dereferenceable(4) %25, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) #33, !dbg !69

Caused by:
Stacktrace:
 [1] Branch
   @ ~/Documents/projects/earley/common.jl:40

Stacktrace:
  [1] julia_type_rule(direction::Int32, ret::Ptr{Enzyme.API.EnzymeTypeTree}, args::Ptr{Ptr{Enzyme.API.EnzymeTypeTree}}, known_values::Ptr{Enzyme.API.IntList}, numArgs::UInt64, val::Ptr{LLVM.API.LLVMOpaqueValue})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:4510
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/7ekWs/src/api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Const{Branch}, Const{Branch}, Duplicated{Vector{NodeLayer}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:4762
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:5854
  [5] _thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6321 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Branch, Branch, Vector{NodeLayer}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6315
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6359
  [8] #s836#163
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6419 [inlined]
  [9] var"#s836#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
 [11] thunk
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6452 [inlined]
 [12] thunk (repeats 2 times)
    @ ~/.julia/packages/Enzyme/7ekWs/src/compiler.jl:6445 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:199 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/7ekWs/src/Enzyme.jl:236 [inlined]
 [15] main()
    @ Main ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:25
 [16] top-level scope
    @ ~/Documents/projects/earley/enzyme_tree_mwe_explicit.jl:28
wsmoses commented 1 year ago

The GC bug is now resolved; all that remains is the getfield issue, which makes this a duplicate of: https://github.com/EnzymeAD/Enzyme.jl/issues/176

freddycct commented 1 year ago

@wsmoses I'm still getting an error for this. Can we re-open this ticket?

FAILED CT: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer} adding Seq: [-1,8,0,-1]: Float@float
Assertion failed: (found->second == BaseType::Pointer), function insert, file /workspace/srcdir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeTree.h, line 184.

signal (6): Abort trap: 6
in expression starting at /Users/freddy/Documents/earley/enzyme_tree_mwe.jl:90
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 42365007 (Pool: 42338255; Big: 26752); GC: 39