EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
422 stars 58 forks source link

Crash on `OMEinsum.einsum!` #1416

Closed mofeing closed 2 weeks ago

mofeing commented 2 months ago

I'm running into a huge error when trying to run the code below. I'm not sure whether I'm missing a rule or whether this is an Enzyme bug.

error.txt

# Reproducer for the crash in `OMEinsum.einsum!` (Enzyme.jl issue #1416).
using Enzyme
using OMEinsum

# Computes a scalar from four 2x2 matrices using three in-place `einsum!` calls:
#   ssa5 <- ssa1 * ssa2  (labels (:A,:C),(:C,:B) -> (:A,:B); :C is contracted)
#   ssa6 <- ssa3 * ssa4  (same pattern)
#   ssa7 <- full contraction of ssa5 and ssa6 over (:A,:B) into a 0-dim array
# ssa5, ssa6 and ssa7 are caller-provided output buffers that are overwritten.
# NOTE(review): the `true, false` arguments are presumably the scale factors for the new
# result and for the old contents of the output (y = 1*result + 0*y). Confirm against
# the OMEinsum `einsum!` docstring.
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2}, ssa5::Array{Float64, 2}, ssa6::Array{Float64, 2}, ssa7::Array{Float64, 0})
    einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa1, ssa2), ssa5, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa1, ssa2)))
    einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa3, ssa4), ssa6, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa3, ssa4)))
    einsum!(((:A, :B), (:A, :B)), (), (ssa5, ssa6), ssa7, true, false, (OMEinsum.get_size_dict)(((:A, :B), (:A, :B)), (ssa5, ssa6)))
    # `only` unwraps the single-element output, so the function returns a scalar.
    # An `Active` return in reverse mode requires a scalar.
    return only(ssa7)
end

# Four random 2x2 inputs, two 2x2 scratch buffers (ssa5/ssa6) and a 0-dim output (ssa7).
x = [rand(2,2) for _ in 1:4]
tmp = [rand(2,2), rand(2,2)]
y = fill(0.0)
# Zero-initialised shadows, one per argument, in the same order as `[x..., tmp..., y]`.
∇ = [zero.(x)..., zero.(tmp)..., zero(y)]

# Primal call: sanity check that the forward computation runs.
f(x..., tmp..., y)

# Reverse-mode gradient. Each array is `Duplicated` with its shadow from `∇`.
autodiff(Reverse, f, Active, Duplicated.([x..., tmp..., y], ∇)...)
wsmoses commented 2 months ago

hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?

mofeing commented 2 months ago

I can also replicate this with TensorOperations. Both of these "einsum" packages should be calling BLAS for these simple examples.

# Reproducer for the same failure, expressed with TensorOperations instead of OMEinsum.
using Enzyme
using TensorOperations

# ssa5 = ssa1*ssa2 and ssa6 = ssa3*ssa4 (index 2 is contracted in each). Then ssa5 and
# ssa6 are fully contracted (empty output index tuple `()`) into a 0-dim array.
# NOTE(review): `:N` is presumably the "no conjugation" flag of `tensorcontract`.
# Confirm against the TensorOperations documentation.
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2})
    ssa5 = tensorcontract((1,3), ssa1, (1,2), :N, ssa2, (2,3), :N)
    ssa6 = tensorcontract((1,3), ssa3, (1,2), :N, ssa4, (2,3), :N)
    ssa7 = tensorcontract((), ssa5, (1,2), :N, ssa6, (1,2), :N)
    # Unwrap the single-element result so the `Active` return is a scalar.
    return only(ssa7)
end

# Four random 2x2 matrices and their zero-initialised shadows.
x = [rand(2,2) for _ in 1:4]
∇ = zero.(x)

# Primal call: sanity check.
f(x...)

# Reverse-mode gradient with every input `Duplicated`.
autodiff(Reverse, f, Active, Duplicated.(x, ∇)...)
wsmoses commented 2 months ago

Going to still need this simpler; it generates hundreds of thousands of lines of IR -- which I can look through, but it'd be much more time-efficient otherwise.

mofeing commented 2 months ago

> hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?

The smallest I can do is just taking 2 vectors and doing a dot product.

# Smallest `einsum!` reproducer: the dot product of two vectors.
using Enzyme
using OMEinsum

# Writes the contraction of ssa1 and ssa2 over their shared label :A into the 0-dim
# array ssa3 (empty output labels `()`), then returns that value as a scalar.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
    einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
    # `only` turns the 0-dim output into a scalar, as an `Active` return requires.
    return only(ssa3)
end

# Two random length-2 vectors, a 0-dim output, and zero shadows in argument order.
x = [rand(2) for _ in 1:2]
y = fill(0.0)
∇ = [zero.(x)..., zero(y)]

# Primal call: sanity check.
f(x..., y)

# Reverse-mode gradient with every array `Duplicated`.
autodiff(Reverse, f, Active, Duplicated.([x..., y], ∇)...)

EDIT: Fixed some typos.

wsmoses commented 2 months ago

would you be able to inline the definitions/macros/etc from einsum? (and possibly simplify)

mofeing commented 2 months ago

mmm I managed to skip one layer of OMEinsum and simplify this case to the following:

# Reproducer that skips the `einsum!` layer and calls OMEinsum's binary kernel directly.
using Enzyme
using OMEinsum

# Dot-product rule: both inputs carry label 'j' and the output has no labels, so 'j' is
# summed over.
rule = OMEinsum.SimpleBinaryRule{('j',), ('j',), ()}()

# Applies `rule` to ssa1 and ssa2, writing into ssa3.
# NOTE(review): `rule` is a non-const global, so its type is not known inside `h`.
# `h` returns whatever `binary_einsum!` returns (presumably the output array ssa3).
# Returning an array with an `Active` return annotation produces the
# "Cannot add one in place to immutable value" error shown below.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

x = [rand(2), rand(2)]
y = fill(0.0)
# Zero shadows for [ssa1, ssa2, ssa3].
∇ = map(zero, [x...,y])

# Primal call: sanity check.
h(x..., y)

# Differentiate `h`, defined above (not `f` from an earlier snippet); the reported stack
# trace for this run shows `typeof(h)`.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

And now I'm getting this error:

Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler [~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59](https://file+.vscode-resource.vscode-cdn.net/Users/mofeing/Developer/k-local-gradient-descent/notebooks/~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59)
Enzyme Mutability Error: Cannot add one in place to immutable value fill(0.0)

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] add_one_in_place
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
 [3] augmented_julia_h_9637wrap
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [4] macro expansion
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
 [5] enzyme_call
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
 [6] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
 [7] autodiff
   @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
 [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
   @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
 [9] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2
wsmoses commented 2 months ago

So that is a different issue, any luck minimizing the previous error?

wsmoses commented 2 months ago

but you should probably replace fill(0) with something like zeros(2) to work around the latter one for now, if need be.

mofeing commented 2 months ago

Ah, no. This is the only simplification I could do. It's all or nothing 🥲.

Also, I don't call fill(0.0) inside h... I made a small modification (now it's doing an element-wise multiplication of 2 vectors) and it's strange:

# Element-wise variant of the direct `binary_einsum!` reproducer.
using Enzyme
using OMEinsum

# Element-wise product rule: label 'l' appears in both inputs and in the output, so
# nothing is summed over.
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}()

# Applies `rule` to ssa1 and ssa2, writing into ssa3.
# NOTE(review): `rule` is a non-const global, so its type is not known inside `h`.
# `h` returns whatever `binary_einsum!` returns (presumably the length-2 output array).
# An array return marked `Active` produces the mutability error shown below.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

x = [rand(2), rand(2)]
y = zeros(2)
# Zero shadows for [ssa1, ssa2, ssa3].
∇ = map(zero, [x...,y])

# Primal call: sanity check.
h(x..., y)

# Differentiate `h`, defined above (not `f` from an earlier snippet); the reported stack
# trace for this run shows `typeof(h)`.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

which returns

Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]

Stacktrace:
 [1] error
   @ ./error.jl:35
 [2] add_one_in_place
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
 [3] augmented_julia_h_13355wrap
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [4] macro expansion
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
 [5] enzyme_call
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
 [6] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
 [7] autodiff
   @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
 [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}})
   @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
 [9] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

I think I'm gonna need to add manually some rules for Enzyme and TensorOperations.

wsmoses commented 2 months ago

So the last error implies you are returning an array not a scalar

On Tue, May 7, 2024 at 4:15 PM Sergio Sánchez Ramírez < @.***> wrote:

Ah, no. This is the only simplification I could do. It's all or nothing 🥲.

Also, I don't call fill(0.0) inside h... I did a small modification (now it's doing a element-wise multiplication of 2 vectors) and it's strange:

# (Email quote; code re-formatted after the mail client collapsed the lines.)
using Enzyme
using OMEinsum

# Element-wise product rule: label 'l' is kept in the output, so nothing is summed.
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}()

# Applies `rule` to ssa1 and ssa2, writing into ssa3, and returns whatever
# `binary_einsum!` returns (presumably the output array).
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

x = [rand(2), rand(2)]
y = zeros(2)
∇ = map(zero, [x...,y])

h(x..., y)

# Differentiate `h`; the quoted stack trace shows `typeof(h)` was what actually ran.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

which returns

Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]

Stacktrace: [1] error @ ./error.jl:35 [2] add_one_in_place @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined] [3] augmented_julia_h_13355wrap @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0 [4] macro expansion @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined] [5] enzyme_call @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined] [6] AugmentedForwardThunk @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined] [7] autodiff @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined] [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}) @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303 [9] top-level scope @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

I think I'm gonna need to add manually some rules for Enzyme and TensorOperations.

— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1416#issuecomment-2099224257, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXCNGEJSHEFHJNJM77TZBEY45AVCNFSM6AAAAABHJ3YSESVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDAOJZGIZDIMRVG4 . You are receiving this because you commented.Message ID: <EnzymeAD/Enzyme. @.***>

mofeing commented 2 months ago

Yes, in this last example is returning an array, and in the previous example a scalar.

But in both cases I get Enzyme Mutability Error: Cannot add one in place to immutable value.

wsmoses commented 2 months ago

You cannot return an array when the return is marked active, you must return a scalar, so you should do only(...) for the last program

mofeing commented 2 months ago

Ah okay, I just found the limitation in autodiff(::Reverse)

If I add an `only` to h,

# Variant of `h` that returns a scalar, as required when the return is marked `Active`.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    # Unwrap the single-element output (`only` throws unless there is exactly one element).
    return only(ssa3)
end

# NOTE(review): `rule`, `ssa1`, `ssa2` and `ssa3` are not defined in this snippet. They
# presumably come from the surrounding REPL session: the printed gradient shows two
# length-2 vectors and a 0-dim array. Define them before re-running standalone.
∇ = map(zero, [ssa1,ssa2,ssa3])
autodiff(Reverse, h, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)

Then, I get:

julia> ∇
3-element Vector{Array{Float64}}:
 [0.26640132425362695, 0.906387409016582]
 [0.6451817492001096, 0.9773755209533468]
 fill(0.0)
mofeing commented 2 months ago

So I guess the problem is in between einsum! and binary_einsum!, which is the code in... https://github.com/under-Peter/OMEinsum.jl/blob/327cf355c746e9f646c5beee74dcd2c11aa90240/src/einsum.jl#L99-L117

wsmoses commented 2 months ago

The mutability error is not a significant issue (it usually means you returned an array rather than scalar like here). The other issue with the long trace is the one I need a MWE for to fix

wsmoses commented 1 month ago

@mofeing I merged a jll bump which fixes some phi node issues. Can you see if this persists?

If not, we should close.

mofeing commented 1 month ago

nothing seems to have changed 🥲

but I think I might be on the way to having a MWE (this issue might not be the only source of error). I noticed that the einsum! function uses @debug, and I can get Enzyme to run indefinitely by making it print a variable.

this works correctly

using OMEinsum
using Enzyme

# Two random length-2 vectors and a 0-dim output array.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y

# Variant whose `@debug` logs only a message; reported to differentiate correctly.
function u(ssa1, ssa2, ssa3)
    # NOTE(review): `i1`, `i2` and `iyb` are not defined in this snippet. They are
    # presumably globals left over from an earlier `OMEinsum.analyze_binary` call in the
    # session. Because they are runtime values, the type of `rule` cannot be inferred here.
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    @debug "asdf"

    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    # Return a scalar, as an `Active` return requires.
    return only(ssa3)
end

this runs indefinitely

using OMEinsum
using Enzyme

# Two random length-2 vectors and a 0-dim output array.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y

# Variant whose `@debug` also logs the `rule` variable; reported to make Enzyme run
# indefinitely.
function u(ssa1, ssa2, ssa3)
    # NOTE(review): `i1`, `i2` and `iyb` are not defined in this snippet. They are
    # presumably session globals from an earlier `OMEinsum.analyze_binary` call, so the
    # type of `rule` cannot be inferred here.
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    @debug "asdf" rule

    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    # Return a scalar, as an `Active` return requires.
    return only(ssa3)
end
wsmoses commented 1 month ago

that's a useful issue to minimize/understand, but separate from the one you found. If you can minimize either individual issue, we can fix that issue.

mofeing commented 1 month ago

I've updated to the latest Enzyme (v0.12.6) and the error seems a lil bit different (No more big explosions!). For the function f below,

# Dot product of ssa1 and ssa2 (shared label :A, empty output labels) written into the
# 0-dim array ssa3 via `einsum!`; returns the result as a scalar.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
    einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
    return only(ssa3)
end

the error is the following:

Error on `autodiff(Reverse, f, ...)` ```julia julia> ssa1, ssa2, ssa3 = rand(2), rand(2), zeros() julia> ∇ = map(zero, [ssa1, ssa2, ssa3]) julia> autodiff(Reverse, f, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...) Enzyme execution failed. Mismatched activity for: store {} addrspace(10)* %.fca.0.0.1.0.0.extract5, {} addrspace(10)* addrspace(10)* %.fca.0.0.1.0.0.gep6, align 8, !dbg !72, !noalias !80 const val: %.fca.0.0.1.0.0.extract5 = extractvalue [2 x [1 x {} addrspace(10)*]] %0, 0, 0, !dbg !72 Type tree: {[-1]:Pointer} You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now Stacktrace: [1] ntuple @ ./ntuple.jl:49 [2] copy @ ./broadcast.jl:1118 [3] materialize @ ./broadcast.jl:903 [4] einsum! @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100 [5] einsum! @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0 Stacktrace: [1] throwerr(cstr::Cstring) @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325 [2] ntuple @ ./ntuple.jl:49 [inlined] [3] copy @ ./broadcast.jl:1118 [inlined] [4] materialize @ ./broadcast.jl:903 [inlined] [5] einsum! @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100 [inlined] [6] einsum! 
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0 [inlined] [7] augmented_julia_einsum__4300_inner_1wrap @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0 [8] macro expansion @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined] [9] enzyme_call @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined] [10] AugmentedForwardThunk @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5275 [inlined] [11] runtime_generic_augfwd(activity::Type{Val{(false, false, false, true, true, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, false, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(einsum!), df::Nothing, primal_1::Tuple{Tuple{Symbol}, Tuple{Symbol}}, shadow_1_1::Nothing, primal_2::Tuple{}, shadow_2_1::Nothing, primal_3::Tuple{Vector{Float64}, Vector{Float64}}, shadow_3_1::Tuple{Vector{Float64}, Vector{Float64}}, primal_4::Array{Float64, 0}, shadow_4_1::Array{Float64, 0}, primal_5::Bool, shadow_5_1::Nothing, primal_6::Bool, shadow_6_1::Nothing, primal_7::Dict{Symbol, Int64}, shadow_7_1::Nothing) @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/rules/jitrules.jl:179 [12] f @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:5 [inlined] [13] diffejulia_f_2479wrap @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0 [14] macro expansion @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined] [15] enzyme_call @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined] [16] CombinedAdjointThunk @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined] [17] autodiff @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined] [18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(f), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}}) @ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303 [19] top-level scope @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2 ```

I've taken the source code of the OMEinsum.einsum! method that I'm using and removed the @debug expressions and an if branch that is not taken (only the else part is run). The result is the u method below:

# Hand-inlined copy of the binary path of `OMEinsum.einsum!` for the dot product of
# ssa1 and ssa2 into the 0-dim array ssa3 (the `@debug` lines and the untaken `if`
# branch are removed). Returns the result as a scalar.
function u(ssa1, ssa2, ssa3)
    # Label type used for the index vectors below.
    LT = Symbol
    ixs = ((:A,), (:A,))
    iy = ()

    # Presumably converts the output/input label tuples into vectors of `LT`.
    iyv = OMEinsum._collect(LT,iy)
    ix1v,ix2v = OMEinsum._collect.(Ref(LT), ixs)

    # Label => dimension size. The reported failure for `u` points into this call.
    size_dict = (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2))

    # Marking `analyze_binary` inactive via `EnzymeRules.inactive` was reported to make
    # `u` differentiate successfully.
    c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
    # Built from runtime label vectors, so the concrete rule type is not inferable here.
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    # NOTE(review): presumably these bring each input to the label layout/shape the
    # binary kernel expects (c1/c2, s1/s2 from `analyze_binary`). Confirm in OMEinsum.
    xs1 = OMEinsum.simplifyto(ix1v, c1, ssa1, size_dict)
    xs2 = OMEinsum.simplifyto(ix2v, c2, ssa2, size_dict)
    x1_ = OMEinsum.safe_reshape(xs1, s1)
    x2_ = OMEinsum.safe_reshape(xs2, s2)

    OMEinsum.binary_einsum!(rule, x1_, x2_, ssa3, true, false)
    # Return a scalar, as an `Active` return requires.
    return only(ssa3)
end

This u method fails similarly (but not equally, look for example the Type tree) as f:

Error on `autodiff(Reverse, u, ...)` ```julia julia> autodiff(Reverse, u, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...) Enzyme execution failed. Mismatched activity for: %unbox24.fca.4.load.pn = phi {} addrspace(10)* [ %unbox24.fca.4.load, %L60 ], [ %unbox31.unpack61, %L64 ] const val: %unbox24.fca.4.load = load {} addrspace(10)*, {} addrspace(10)** %unbox24.fca.4.gep, align 8, !dbg !136 Type tree: {} You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now Stacktrace: [1] u @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:12 Stacktrace: [1] throwerr(cstr::Cstring) @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325 [2] iterate @ ./range.jl:897 [inlined] [3] copyto! @ ./abstractarray.jl:942 [inlined] [4] _collect @ ./array.jl:696 [inlined] [5] collect @ ./array.jl:694 [inlined] [6] #91 @ ./none:0 [inlined] [7] iterate @ ./generator.jl:47 [inlined] [8] collect @ ./array.jl:834 [inlined] [9] get_size_dict! 
@ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:61 [inlined] [10] get_size_dict @ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:100 [inlined] [11] get_size_dict @ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:99 [inlined] [12] u @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:9 [inlined] [13] diffejulia_u_5824wrap @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0 [14] macro expansion @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined] [15] enzyme_call @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined] [16] CombinedAdjointThunk @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined] [17] autodiff @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined] [18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(u), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}}) @ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303 [19] top-level scope @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:26 ```

By setting OMEinsum.analyze_binary to be inactive, u works but f continues to give the same error:

julia> EnzymeRules.inactive(::typeof(OMEinsum.analyze_binary), args...) = nothing  # mark analyze_binary as inactive so Enzyme does not differentiate through it

julia> ∇ = map(zero, [ssa1,ssa2,ssa3])

julia> autodiff(Reverse, u, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59

((nothing, nothing, nothing),)

julia> ∇
3-element Vector{Array{Float64}}:
 [0.056936369463587044, 0.29215353698534263]
 [0.4591629765253664, 0.672870606539631]
 fill(0.0)
mofeing commented 1 month ago

This is not working either...

# Hand-inlined copy of the binary path of `OMEinsum.einsum!`. The `cy != iyv` branch is
# kept commented out for reference. Used to test whether the failure comes from how
# arguments are passed into `einsum!`.
#
# Arguments:
#   ixs       – tuple of the two input label tuples
#   iy        – output label tuple
#   xs        – the two input arrays
#   y         – output array; written through `safe_reshape(y, s3)`
#   sx, sy    – forwarded to `binary_einsum!` (presumably scale factors for the new
#               result and the old contents of `y`; confirm in OMEinsum)
#   size_dict – label => dimension size
# Returns `y`.
function custom_einsum!(ixs, iy, @nospecialize(xs::NTuple{2, Any}), @nospecialize(y), sx, sy, size_dict::Dict{LT}) where LT
    iyv = OMEinsum._collect(LT,iy)
    ix1v, ix2v = OMEinsum._collect.(Ref(LT), ixs)

    x1, x2 = xs
    c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
    # Built from runtime label vectors, so the concrete rule type is not inferable here.
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
    xs1 = OMEinsum.simplifyto(ix1v, c1, x1, size_dict)
    xs2 = OMEinsum.simplifyto(ix2v, c2, x2, size_dict)
    x1_ = OMEinsum.safe_reshape(xs1, s1)
    x2_ = OMEinsum.safe_reshape(xs2, s2)

    # if cy != iyv
    #     y_ = similar(y, (s3...,))
    #     y_ = reshape(OMEinsum.binary_einsum!(rule, x1_, x2_, y_, true, false), [size_dict[x] for x in cy]...)
    #     return custom_einsum!((cy,), iyv, (y_,), y, sx, sy, size_dict)
    # else
        # Single-step path: write the result directly into `y`, viewed with shape `s3`.
        OMEinsum.binary_einsum!(rule, x1_, x2_, OMEinsum.safe_reshape(y, s3), sx, sy)
        return y
    # end
end

# Dot product of ssa1 and ssa2 into the 0-dim array ssa3 via `custom_einsum!`;
# returns the result as a scalar.
function custom_f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
    custom_einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
    return only(ssa3)
end

x = [rand(2) for _ in 1:2]
y = zeros()
# Bind the names used below, so that the snippet is self-contained and the gradient is
# taken on the same arrays passed to the primal call.
ssa1, ssa2, ssa3 = x..., y

# Primal call: sanity check.
custom_f(x..., y)

# Zero shadows in argument order, then the reverse-mode gradient.
∇ = map(zero, [ssa1, ssa2, ssa3])
autodiff(Reverse, custom_f, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...)

Could it be that the problem is around the argument passing of einsum!? Or maybe the u example works because there is some kind of constant folding/propagation happening before LLVM?

wsmoses commented 1 month ago

@mofeing with a bunch of fixes now landed, how does this work presently?