EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 64 forks source link

Enzyme is dropping gradients when using custom rule and views #1856

Open ptiede opened 1 month ago

ptiede commented 1 month ago

I've managed to get into a situation where a custom rule seems to give incorrect results when used within a loop. Here is an MWE:

# Problem setup for the reproduction.
# vl[i] is the row count of the i-th matrix stored in Bs; tot (their sum) is
# the length of the output buffer that f allocates.
vl = [2, 2]
tot = sum(vl)
# Index ranges; each one is collected into a Vector in visinds below.
rg = [1:2, 3:4]
Nx = 4

# One CartesianIndex per block. f uses it both as the key into Bs and as the
# trailing index of x.
iminds = reshape([CartesianIndex(i) for i in 1:2], :)
visinds = [collect(rg[i]) for i in eachindex(rg)]
# Each value is a vl[i] x Nx*Nx matrix of complex ones, keyed by iminds[i].
Bs = Dict((iminds[i]=> ones(ComplexF64, vl[i], Nx*Nx) for i in eachindex(vl)))

# In-place product: overwrite `out` with `A * b` through `mul!`, then hand back
# `nothing` instead of `out`.
_mul!(out, A, b) = (mul!(out, A, b); nothing)

# Augmented-primal (forward) half of the custom reverse-mode rule for `_mul!`.
# The method only applies for width-1 reverse configs, a Const return
# annotation, Duplicated `out` and `b`, and a Const `A` wrapping a Matrix.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    # Run the primal computation on the primal values only.
    _mul!(out.val, A.val, b.val)
    # All three AugmentedReturn slots are `nothing`: nothing is returned and
    # nothing is saved for the reverse pass.
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

# Reverse half of the custom rule for `_mul!`, with the same argument
# annotations as the augmented primal. `tape` is not used.
@noinline function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
                                       ::Const{typeof(_mul!)},
                                       ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
                                       b::Duplicated)

    # Accumulate the real part of A' * (shadow of out) into the shadow of b.
    b.dval .+= real.(A.val' * out.dval)
    # Reset the shadow of out to zero once it has been propagated.
    out.dval .= 0
    # One entry per argument (out, A, b).
    return (nothing, nothing, nothing)
end

# Objective that is differentiated. For every pair (iminds[i], visinds[i]) it
# writes Bs[imind] * vec(x[:, :, imind]) into out[visind] through `_mul!`, and
# finally returns the sum of squared magnitudes of `out`.
@inline function f(Bs, visinds, iminds, tot, x)
    # Complex buffer of length `tot` whose real element type follows x.
    out = similar(x, Complex{eltype(x)}, tot)
    for i in eachindex(iminds, visinds)
        imind = iminds[i]
        visind = visinds[i]
        # Destination is a view of `out`; the right-hand side is a view of x
        # reshaped to a vector.
        _mul!(@view(out[visind]), Bs[imind], reshape(@view(x[:, :, imind]), :))
    end
    return sum(abs2, out)
end

# Input of ones and a zero shadow of the same shape that receives the gradient.
x = ones(Nx, Nx, 2)
dx = zero(x)

# Plain primal evaluation first, then reverse-mode AD with runtime activity.
# Only x is Duplicated; every other argument is Const.
f(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx
# Observed output (the x[:, :, 1] slice is all zeros):
# dx = [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0;;; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]

So the gradients for x[:, :, 1] are zeroed rather than also being filled with 64. Funnily enough, if I change vl so that its entries differ, e.g.,

# Variant of the setup in which the two blocks have different lengths (2 and 3).
vl = [2, 3]
tot = sum(vl)
rg = [1:2, 3:5]
Nx = 4

and keep everything else identical, I get a `DimensionMismatch` error:

ERROR: DimensionMismatch: matrix A has dimensions (16,2), vector B has length 3
  [1] _generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:697
  [2] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:687 [inlined]
  [3] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [5] *
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [6] reverse
    @ ~/Research/Enzyme/dft.jl:32 [inlined]
  [7] f
    @ ~/Research/Enzyme/dft.jl:42 [inlined]
  [8] diffejulia_f_7370wrap
    @ ~/Research/Enzyme/dft.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:7045 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6648 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6525 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:316 [inlined]
 [13] autodiff(::ReverseMode{…}, ::typeof(f), ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:328
 [14] top-level scope
    @ ~/Research/Enzyme/dft.jl:51

so it looks like the shadow is incorrect.

Note that when acting on a single array the rule is correct, e.g.,

# Control case without a loop and without views: a single `_mul!` into a
# freshly allocated buffer, followed by the same sum-of-squared-magnitudes
# reduction.
function fsimple(Bs, x)
    # Complex buffer with one entry per row of Bs.
    out = similar(Bs, Complex{eltype(x)}, size(Bs, 1))
    _mul!(out, Bs, reshape(x, :))
    return sum(abs2, out)
end

# Single 2 x 16 complex matrix and a 4 x 4 real input with a zero shadow.
B = ones(ComplexF64, 2, 16)
xx = ones(4,4)
dxx = zero(xx)
autodiff(set_runtime_activity(Reverse), fsimple, Active, Const(B), Duplicated(xx, dxx))
@show dxx
# Observed output (every entry is 64):
# dxx = [64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]
ptiede commented 1 month ago

Oh also this is on Enzyme#main

(Enzyme) pkg> st
Status `~/Research/Enzyme/Project.toml`
  [7da242da] Enzyme v0.13.0 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [f151be2c] EnzymeCore v0.8.0 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`
wsmoses commented 1 month ago

@ptiede can you simplify this? I tried to do so but I really don't understand what is supposed to be happening

ptiede commented 1 month ago

I've reduced it a bit. It is essentially just a loop over a matvec where the vec is a view. Let me know if this isn't enough.

using Enzyme
using EnzymeCore: EnzymeRules
using LinearAlgebra

# Nx matrices, each Nx x Nx*Nx and filled with complex ones.
Nx = 2
Bs = [ones(ComplexF64, Nx, Nx*Nx) for i in 1:Nx]

# In-place product: overwrite `out` with `A * b` through `mul!`, then hand back
# `nothing` instead of `out`.
_mul!(out, A, b) = (mul!(out, A, b); nothing)

# Allocating counterpart of `_mul!`: creates the output buffer, fills it with
# `_mul!`, and returns it.
function _mul(A, b)
    # Same element type as A, one entry per row of A.
    out = similar(A, size(A, 1))
    _mul!(out, A, b)
    return out
end

# Augmented-primal (forward) half of the custom reverse-mode rule for `_mul!`.
# The method only applies for width-1 reverse configs, a Const return
# annotation, Duplicated `out` and `b`, and a Const `A` wrapping a Matrix.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    # Run the primal computation on the primal values only.
    _mul!(out.val, A.val, b.val)
    # All three AugmentedReturn slots are `nothing`: nothing is returned and
    # nothing is saved for the reverse pass.
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

# Reverse half of the custom rule for `_mul!`, with the same argument
# annotations as the augmented primal. `tape` is not used.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
    b::Duplicated)

     # Accumulate the real part of A' * (shadow of out) into the shadow of b.
     b.dval .+= real.(A.val' * out.dval)
     # Reset the shadow of out to zero once it has been propagated.
     out.dval .= 0
    # One entry per argument (out, A, b).
    return (nothing, nothing, nothing)
end

# Objective that is differentiated: for i in 1:Nx, multiply Bs[i] with the
# flattened slice x[:, :, i] and add the sum of squared magnitudes of the
# product to the running total.
@inline function f2(Nx, Bs, x)
    s = zero(eltype(x))
    for i in 1:Nx
        # The right-hand side is a view of x reshaped to a vector. Bounds checks
        # are disabled, so Nx must not exceed length(Bs) or size(x, 3).
        @inbounds s += sum(abs2, _mul(Bs[i], reshape(@view(x[:, :, i]), :)))
    end
    return s
end

# Input of ones and a zero shadow. NOTE(review): the trailing dimension is the
# literal 2, which only matches the 1:Nx loop in f2 because Nx == 2 here.
x = ones(Nx, Nx, 2)
dx = zero(x)

# Plain primal evaluation first, then reverse-mode AD with runtime activity.
f2(Nx, Bs, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Nx), Const(Bs), Duplicated(x, fill!(dx, 0)))
@show dx
# Observed output (first slice is zero, second slice holds 32 instead of 16):
# dx = [0.0 0.0; 0.0 0.0;;; 32.0 32.0; 32.0 32.0]

# correct answer is fill(16.0, Nx, Nx, Nx)
ptiede commented 1 month ago

Further reduction

using Enzyme, LinearAlgebra
# A single Nx x Nx*Nx matrix of complex ones, shared by every loop iteration.
Nx = 2
Bs = ones(ComplexF64, Nx, Nx*Nx)

# In-place product: overwrite `out` with `A * b` through `mul!`, then hand back
# `nothing` instead of `out`.
_mul!(out, A, b) = (mul!(out, A, b); nothing)

# Allocating counterpart of `_mul!`: creates the output buffer, fills it with
# `_mul!`, and returns it.
function _mul(A, b)
    # Same element type as A, one entry per row of A.
    out = similar(A, size(A, 1))
    _mul!(out, A, b)
    return out
end

# Augmented-primal (forward) half of the custom reverse-mode rule for `_mul!`.
# The method only applies for width-1 reverse configs, a Const return
# annotation, Duplicated `out` and `b`, and a Const `A` wrapping a Matrix.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    # Run the primal computation on the primal values only.
    _mul!(out.val, A.val, b.val)
    # All three AugmentedReturn slots are `nothing`: nothing is returned and
    # nothing is saved for the reverse pass.
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

# Reverse half of the custom rule for `_mul!`, with the same argument
# annotations as the augmented primal. `tape` is not used.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
    b::Duplicated)

# Accumulate the real part of A' * (shadow of out) into the shadow of b.
b.dval .+= real.(A.val' * out.dval)
# Reset the shadow of out to zero once it has been propagated.
out.dval .= 0
# One entry per argument (out, A, b).
return (nothing, nothing, nothing)
end

# Objective that is differentiated: for i in 1:Nx, multiply the single matrix
# Bs with the flattened slice x[:, :, i] and add the sum of squared magnitudes
# of the product to the running total.
@inline function f2(Nx, Bs, x)
    s = zero(eltype(x))
    for i in 1:Nx
        # The right-hand side is a view of x reshaped to a vector. Bounds checks
        # are disabled, so Nx must not exceed size(x, 3).
        @inbounds s += sum(abs2, _mul(Bs, reshape(@view(x[:, :, i]), :)))
    end
    return s
end

# Input of ones and a zero shadow. NOTE(review): the trailing dimension is the
# literal 2, which only matches the 1:Nx loop in f2 because Nx == 2 here.
x = ones(Nx, Nx, 2)
dx = zero(x)

# Plain primal evaluation first, then reverse-mode AD with runtime activity.
f2(Nx, Bs, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Nx), Const(Bs), Duplicated(x, fill!(dx, 0)))
@show dx
wsmoses commented 1 month ago

what is this supposed to give/why is it wrong?

For one thing you're storing into constant data (Bs) which means that it has a zero derivative guaranteed. see https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage

wsmoses commented 1 month ago
using Enzyme, LinearAlgebra
using EnzymeCore: EnzymeRules

# Square of the first element of `b`. Despite the trailing `!`, nothing is
# mutated.
function _mul!(b)
    lead = b[1]
    return lead * lead
end

# Augmented-primal half of the custom rule for the single-argument `_mul!`,
# for an Active return and a Duplicated argument.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      b::Duplicated)
    # The primal result is only computed when the config asks for it; the
    # remaining two slots are `nothing`.
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(b.val) : nothing, nothing, nothing)
end

# Reverse half of the custom rule for the single-argument `_mul!`. `out` carries
# the incoming derivative of the return value (read through `out.val`); `tape`
# is not used. The `@show` lines are debugging output.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape,  b::Duplicated)
    @show b.val
    # Dump the view internals of the primal and of the shadow so the two can be
    # compared; these field accesses presume both are SubArray-like views.
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    # d(b[1]^2)/d(b[1]) = 2 * b[1], scaled by the incoming derivative.
    b.dval[1] += 2 * b.val[1] * out.val
    @show b.dval
    @show out
    # One entry for the single argument b.
    return (nothing,)
end

# Objective that is differentiated: sum of `_mul!` applied to a view of each
# column of b.
@inline function f2(b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        # The argument is a column view of b, not a copy.
        out = _mul!(@view(b[:, i]))
        s += out
    end
    return s
end

# NOTE(review): Nx and Ny are not referenced anywhere else in this snippet.
Nx = 1
Ny = 2
# 1 x 2 input matrix and a zero shadow of the same shape.
b = [1.0 4.0]
db = zero(b)

# Plain primal evaluation first, then reverse-mode AD (no runtime activity here).
f2(b)
autodiff(Reverse, f2, Active, Duplicated(b, fill!(db, 0)))
# What Enzyme returns db = [0  16]
# What the correct answer is db = [2 8]
ptiede commented 1 month ago
after simplification :
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f2_11891({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #11 !dbg !154 {
top:
  %newstruct = alloca { [1 x [1 x i64]], i64 }, align 8
  %1 = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  %2 = call {}*** @julia.get_pgcstack() #12
  %ptls_field31 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field31 to i64***
  %ptls_load3233 = load i64**, i64*** %3, align 8, !tbaa !11
  %4 = getelementptr inbounds i64*, i64** %ptls_load3233, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !15
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #12, !dbg !155
  fence syncscope("singlethread") seq_cst
  %5 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !156
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %5, i64 4, !dbg !156
  %6 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !156
  %arraysize = load i64, i64 addrspace(11)* %6, align 16, !dbg !156, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %.not = icmp eq i64 %arraysize, 0, !dbg !158
  br i1 %.not, label %L67, label %L18.preheader, !dbg !157

L18.preheader:                                    ; preds = %top
  %arraysize_ptr8 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %5, i64 3
  %7 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr8 to i64 addrspace(11)*
  %memcpy_refined_dst = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 0, i64 0, i64 0
  %8 = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 1
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 0
  %.fca.1.0.0.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 1, i32 0, i64 0, i64 0
  %.fca.1.1.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 1, i32 1
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 2
  %.fca.3.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 3
  %9 = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  br label %L18, !dbg !162

L18:                                              ; preds = %L42, %L18.preheader
  %iv = phi i64 [ %iv.next, %L42 ], [ 0, %L18.preheader ]
  %value_phi7 = phi double [ %14, %L42 ], [ 0.000000e+00, %L18.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !165
  %arraysize9 = load i64, i64 addrspace(11)* %7, align 8, !dbg !165, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %arraysize12 = load i64, i64 addrspace(11)* %6, align 16, !dbg !169, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %10 = add nsw i64 %iv.next, -1, !dbg !172
  %.not34 = icmp ult i64 %10, %arraysize12, !dbg !176
  br i1 %.not34, label %L42, label %L39, !dbg !162

L39:                                              ; preds = %L18
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  store i64 %iv.next, i64* %8, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  %11 = addrspacecast { [1 x [1 x i64]], i64 }* %newstruct to { [1 x [1 x i64]], i64 } addrspace(11)*, !dbg !162
  call fastcc void @julia_throw_boundserror_11895({} addrspace(10)* nofree noundef nonnull align 16 dereferenceable(40) %0, { [1 x [1 x i64]], i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %11) #13, !dbg !162
  unreachable, !dbg !162

L42:                                              ; preds = %L18
  %12 = mul i64 %arraysize9, %10, !dbg !181
  store {} addrspace(10)* %0, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !164, !noalias !191
  store i64 %arraysize9, i64* %.fca.1.0.0.0.gep, align 8, !dbg !164, !noalias !191
  store i64 %iv.next, i64* %.fca.1.1.gep, align 8, !dbg !164, !noalias !191
  store i64 %12, i64* %.fca.2.gep, align 8, !dbg !164, !noalias !191
  store i64 1, i64* %.fca.3.gep, align 8, !dbg !164, !noalias !191
  %13 = call double @julia__mul__11897({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %9) #12, !dbg !164
  %14 = fadd double %value_phi7, %13, !dbg !192
  %.not35 = icmp eq i64 %iv.next, %arraysize, !dbg !194
  %15 = add nuw nsw i64 %iv.next, 1, !dbg !195
  br i1 %.not35, label %L67.loopexit, label %L18, !dbg !196

L67.loopexit:                                     ; preds = %L42
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  store i64 %arraysize, i64* %8, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  br label %L67, !dbg !197

L67:                                              ; preds = %L67.loopexit, %top
  %value_phi23 = phi double [ 0.000000e+00, %top ], [ %14, %L67.loopexit ]
  ret double %value_phi23, !dbg !197
}

; Function Attrs: mustprogress willreturn
define internal void @diffejulia_f2_11891({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* align 16 "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %"'", double %differeturn) local_unnamed_addr #11 !dbg !205 {
top:
  %1 = alloca { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, align 8
  %2 = alloca { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, align 8
  %newstruct.i = alloca [1 x i64], align 8
  %newstruct12.i = alloca [1 x i64], align 8
  %"iv'ac" = alloca i64, align 8
  %"value_phi23'de" = alloca double, align 8
  %3 = getelementptr double, double* %"value_phi23'de", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"value_phi7'de" = alloca double, align 8
  %4 = getelementptr double, double* %"value_phi7'de", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de" = alloca double, align 8
  %5 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %5, align 8
  %"'de2" = alloca double, align 8
  %6 = getelementptr double, double* %"'de2", i64 0
  store double 0.000000e+00, double* %6, align 8
  %7 = call {}*** @julia.get_pgcstack()
  %8 = call {}*** @julia.get_pgcstack()
  %9 = call {}*** @julia.get_pgcstack()
  %newstruct = alloca { [1 x [1 x i64]], i64 }, align 8
  %"'ipa" = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } zeroinitializer, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", align 8
  %10 = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  %11 = call {}*** @julia.get_pgcstack() #20
  %ptls_field31 = getelementptr inbounds {}**, {}*** %11, i64 2
  %12 = bitcast {}*** %ptls_field31 to i64***
  %ptls_load3233 = load i64**, i64*** %12, align 8, !tbaa !18, !alias.scope !206, !noalias !209
  %13 = getelementptr inbounds i64*, i64** %ptls_load3233, i64 2
  %safepoint = load i64*, i64** %13, align 8, !tbaa !22, !alias.scope !211, !noalias !214
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #20, !dbg !216
  fence syncscope("singlethread") seq_cst
  %14 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !217
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 4, !dbg !217
  %15 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !217
  %arraysize = load i64, i64 addrspace(11)* %15, align 16, !dbg !217, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %.not = icmp eq i64 %arraysize, 0, !dbg !224
  br i1 %.not, label %L67, label %L18.preheader, !dbg !218

L18.preheader:                                    ; preds = %top
  %arraysize_ptr8 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 3
  %16 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr8 to i64 addrspace(11)*
  %memcpy_refined_dst = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 0, i64 0, i64 0
  %17 = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 1
  %".fca.0.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 0
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 0
  %".fca.1.0.0.0.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 1, i32 0, i64 0, i64 0
  %.fca.1.0.0.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 1, i32 0, i64 0, i64 0
  %".fca.1.1.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 1, i32 1
  %.fca.1.1.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 1, i32 1
  %".fca.2.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 2
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 2
  %".fca.3.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 3
  %.fca.3.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 3
  %"'ipc" = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa" to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  %18 = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  %19 = add nsw i64 %arraysize, -1, !dbg !228
  br label %L18, !dbg !228

L18:                                              ; preds = %L42, %L18.preheader
  %iv = phi i64 [ %iv.next, %L42 ], [ 0, %L18.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !231
  %arraysize9 = load i64, i64 addrspace(11)* %16, align 8, !dbg !231, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %arraysize12 = load i64, i64 addrspace(11)* %15, align 16, !dbg !235, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %20 = add nsw i64 %iv.next, -1, !dbg !238
  %.not34 = icmp ult i64 %20, %arraysize12, !dbg !242
  br i1 %.not34, label %L42, label %L39, !dbg !228

L39:                                              ; preds = %L18
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  store i64 %iv.next, i64* %17, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  %21 = addrspacecast { [1 x [1 x i64]], i64 }* %newstruct to { [1 x [1 x i64]], i64 } addrspace(11)*, !dbg !228
  call fastcc void @julia_throw_boundserror_11895({} addrspace(10)* nofree noundef nonnull align 16 dereferenceable(40) %0, { [1 x [1 x i64]], i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %21) #21, !dbg !228
  unreachable, !dbg !228

L42:                                              ; preds = %L18
  %22 = mul i64 %arraysize9, %20, !dbg !247
  store {} addrspace(10)* %"'", {} addrspace(10)** %".fca.0.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store {} addrspace(10)* %0, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %arraysize9, i64* %".fca.1.0.0.0.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %arraysize9, i64* %.fca.1.0.0.0.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %iv.next, i64* %".fca.1.1.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %iv.next, i64* %.fca.1.1.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %22, i64* %".fca.2.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %22, i64* %.fca.2.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 1, i64* %".fca.3.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 1, i64* %.fca.3.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  %23 = bitcast {}*** %9 to {}**, !dbg !230
  %24 = getelementptr inbounds {}*, {}** %23, i64 -14, !dbg !230
  %25 = getelementptr inbounds {}*, {}** %24, i64 16, !dbg !230
  %26 = bitcast {}** %25 to i8**, !dbg !230
  %27 = load i8*, i8** %26, align 8, !dbg !230
  %28 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %24, i64 80, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125882009936336 to {}*) to {} addrspace(10)*)), !dbg !230
  %29 = bitcast {} addrspace(10)* %28 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)*, !dbg !230
  %30 = addrspacecast [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)* %29 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)*, !dbg !230
  %31 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %30, i64 0, i32 0, !dbg !230
  %32 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %18, align 8, !dbg !230
  %33 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %"'ipc", align 8, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %32, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %31, align 8, !dbg !230
  %34 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %30, i64 0, i32 1, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %33, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %34, align 8, !dbg !230
  %35 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %32, 0, !dbg !230
  %36 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %33, 0, !dbg !230
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %28, {} addrspace(10)* %35, {} addrspace(10)* %36), !dbg !230
  %37 = call {}*** @julia.get_pgcstack()
  %ptls_field3.i = getelementptr inbounds {}**, {}*** %37, i64 2
  %38 = bitcast {}*** %ptls_field3.i to i64***
  %ptls_load45.i = load i64**, i64*** %38, align 8, !tbaa !18
  %39 = getelementptr inbounds i64*, i64** %ptls_load45.i, i64 2
  %safepoint.i = load i64*, i64** %39, align 8, !tbaa !22
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i), !dbg !264
  fence syncscope("singlethread") seq_cst
  %.not35 = icmp eq i64 %iv.next, %arraysize, !dbg !267
  br i1 %.not35, label %L67.loopexit, label %L18, !dbg !269

L67.loopexit:                                     ; preds = %L42
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  store i64 %arraysize, i64* %17, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  br label %L67, !dbg !270

L67:                                              ; preds = %L67.loopexit, %top
  br label %invertL67, !dbg !270

inverttop:                                        ; preds = %invertL67, %invertL18.preheader
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  ret void

invertL18.preheader:                              ; preds = %invertL18
  br label %inverttop

invertL18:                                        ; preds = %julia_reverse_11904.exit
  %40 = load double, double* %"value_phi7'de", align 8
  store double 0.000000e+00, double* %"value_phi7'de", align 8
  %41 = load i64, i64* %"iv'ac", align 8
  %42 = icmp eq i64 %41, 0
  %43 = xor i1 %42, true
  %44 = select fast i1 %43, double %40, double 0.000000e+00
  %45 = load double, double* %"'de", align 8
  %46 = fadd fast double %45, %40
  %47 = select fast i1 %42, double %45, double %46
  store double %47, double* %"'de", align 8
  br i1 %42, label %invertL18.preheader, label %incinvertL18

incinvertL18:                                     ; preds = %invertL18
  %48 = load i64, i64* %"iv'ac", align 8
  %49 = add nsw i64 %48, -1
  store i64 %49, i64* %"iv'ac", align 8
  br label %invertL42

invertL42:                                        ; preds = %mergeinvertL18_L67.loopexit, %incinvertL18
  %50 = load double, double* %"'de", align 8, !dbg !271
  store double 0.000000e+00, double* %"'de", align 8, !dbg !271
  %51 = load double, double* %"value_phi7'de", align 8, !dbg !271
  %52 = fadd fast double %51, %50, !dbg !271
  store double %52, double* %"value_phi7'de", align 8, !dbg !271
  %53 = load double, double* %"'de2", align 8, !dbg !271
  %54 = fadd fast double %53, %50, !dbg !271
  store double %54, double* %"'de2", align 8, !dbg !271
  %55 = load i64, i64* %"iv'ac", align 8, !dbg !230
  %_unwrap = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*, !dbg !230
  %56 = load i64, i64* %"iv'ac", align 8, !dbg !230
  %"'ipc_unwrap" = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa" to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*, !dbg !230
  %57 = bitcast {}*** %8 to {}**, !dbg !230
  %58 = getelementptr inbounds {}*, {}** %57, i64 -14, !dbg !230
  %59 = getelementptr inbounds {}*, {}** %58, i64 16, !dbg !230
  %60 = bitcast {}** %59 to i8**, !dbg !230
  %61 = load i8*, i8** %60, align 8, !dbg !230
  %62 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %58, i64 80, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125882009936336 to {}*) to {} addrspace(10)*)), !dbg !230
  %63 = bitcast {} addrspace(10)* %62 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)*, !dbg !230
  %64 = addrspacecast [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)* %63 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)*, !dbg !230
  %65 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i32 0, !dbg !230
  %66 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %_unwrap, align 8, !dbg !230
  %67 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %"'ipc_unwrap", align 8, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %66, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %65, align 8, !dbg !230
  %68 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i32 1, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %67, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %68, align 8, !dbg !230
  %69 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %66, 0, !dbg !230
  %70 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %67, 0, !dbg !230
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %62, {} addrspace(10)* %69, {} addrspace(10)* %70), !dbg !230
  %71 = load double, double* %"'de2", align 8, !dbg !230
  store double 0.000000e+00, double* %"'de2", align 8, !dbg !230
  %72 = bitcast {}*** %7 to {}**, !dbg !230
  %73 = getelementptr inbounds {}*, {}** %72, i64 -14, !dbg !230
  %74 = getelementptr inbounds {}*, {}** %73, i64 16, !dbg !230
  %75 = bitcast {}** %74 to i8**, !dbg !230
  %76 = load i8*, i8** %75, align 8, !dbg !230
  %77 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %73, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125879275533520 to {}*) to {} addrspace(10)*)), !dbg !230
  %78 = bitcast {} addrspace(10)* %77 to [1 x double] addrspace(10)*, !dbg !230
  %79 = addrspacecast [1 x double] addrspace(10)* %78 to [1 x double] addrspace(11)*, !dbg !230
  %80 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %79, i64 0, i32 0, !dbg !230
  store double %71, double addrspace(11)* %80, align 8, !dbg !230
  %81 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 40, i8* %81)
  %82 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 40, i8* %82)
  %83 = bitcast [1 x i64]* %newstruct.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* %83)
  %84 = bitcast [1 x i64]* %newstruct12.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* %84)
  %85 = call {}*** @julia.get_pgcstack()
  %ptls_field43.i = getelementptr inbounds {}**, {}*** %85, i64 2
  %86 = bitcast {}*** %ptls_field43.i to i64***
  %ptls_load4445.i = load i64**, i64*** %86, align 8, !tbaa !18
  %87 = getelementptr inbounds i64*, i64** %ptls_load4445.i, i64 2
  %safepoint.i4 = load i64*, i64** %87, align 8, !tbaa !22
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i4), !dbg !273
  fence syncscope("singlethread") seq_cst
  %88 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, !dbg !276
  %89 = call fastcc nonnull {} addrspace(10)* @julia_repr_11919({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(80) %88), !dbg !282
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879058330480 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %89), !dbg !282
  %90 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 2, !dbg !283
  %getfield_addr.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 0, !dbg !283
  %getfield.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr.i unordered, align 8, !dbg !283, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %91 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 3, !dbg !283
  %92 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 1, i32 0, i64 0, i64 0, !dbg !286
  %unbox.unpack.unpack.unpack.i = load i64, i64 addrspace(11)* %92, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox.elt46.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 1, i32 1, !dbg !286
  %unbox.unpack47.i = load i64, i64 addrspace(11)* %unbox.elt46.i, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox2.i = load i64, i64 addrspace(11)* %90, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox3.i = load i64, i64 addrspace(11)* %91, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %.fca.0.0.0.0.gep33.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 0, i32 0, i64 0, i64 0, !dbg !287
  store i64 %unbox.unpack.unpack.unpack.i, i64* %.fca.0.0.0.0.gep33.i, align 8, !dbg !287, !noalias !288
  %.fca.0.1.gep35.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 0, i32 1, !dbg !287
  store i64 %unbox.unpack47.i, i64* %.fca.0.1.gep35.i, align 8, !dbg !287, !noalias !288
  %.fca.1.gep37.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 1, !dbg !287
  store i64 %unbox2.i, i64* %.fca.1.gep37.i, align 8, !dbg !287, !noalias !288
  %.fca.2.gep39.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 2, !dbg !287
  store {} addrspace(10)* %getfield.i, {} addrspace(10)** %.fca.2.gep39.i, align 8, !dbg !287, !noalias !288
  %.fca.3.gep41.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 3, !dbg !287
  store i64 %unbox3.i, i64* %.fca.3.gep41.i, align 8, !dbg !287, !noalias !288
  %93 = addrspacecast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)*, !dbg !287
  %94 = call fastcc nonnull {} addrspace(10)* @julia_repr_11915({ { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %93), !dbg !287
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879269903824 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %94), !dbg !287
  %95 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, !dbg !291
  %96 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 2, !dbg !291
  %getfield_addr4.i = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %95, i64 0, i32 0, !dbg !291
  %getfield5.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr4.i unordered, align 8, !dbg !291, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %97 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 3, !dbg !291
  %98 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 1, i32 0, i64 0, i64 0, !dbg !294
  %unbox6.unpack.unpack.unpack.i = load i64, i64 addrspace(11)* %98, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox6.elt51.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 1, i32 1, !dbg !294
  %unbox6.unpack52.i = load i64, i64 addrspace(11)* %unbox6.elt51.i, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox7.i = load i64, i64 addrspace(11)* %96, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox8.i = load i64, i64 addrspace(11)* %97, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %.fca.0.0.0.0.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 0, i32 0, i64 0, i64 0, !dbg !295
  store i64 %unbox6.unpack.unpack.unpack.i, i64* %.fca.0.0.0.0.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.0.1.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 0, i32 1, !dbg !295
  store i64 %unbox6.unpack52.i, i64* %.fca.0.1.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.1.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 1, !dbg !295
  store i64 %unbox7.i, i64* %.fca.1.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.2.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 2, !dbg !295
  store {} addrspace(10)* %getfield5.i, {} addrspace(10)** %.fca.2.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.3.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 3, !dbg !295
  store i64 %unbox8.i, i64* %.fca.3.gep.i, align 8, !dbg !295, !noalias !288
  %99 = addrspacecast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)*, !dbg !295
  %100 = call fastcc nonnull {} addrspace(10)* @julia_repr_11915({ { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %99), !dbg !295
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 8 addrspacecast ({}* inttoptr (i64 125881971620392 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %100), !dbg !295
  %memcpy_refined_dst.i = getelementptr inbounds [1 x i64], [1 x i64]* %newstruct.i, i64 0, i64 0, !dbg !296
  store i64 1, i64* %memcpy_refined_dst.i, align 8, !dbg !296, !tbaa !75, !alias.scope !77, !noalias !301
  %bitcast.i = load i64, i64 addrspace(11)* %98, align 8, !dbg !302, !tbaa !22, !alias.scope !30, !noalias !33
  %.not.i = icmp eq i64 %bitcast.i, 0, !dbg !310
  br i1 %.not.i, label %L39.i, label %L42.i, !dbg !312

L39.i:                                            ; preds = %invertL42
  %101 = addrspacecast [1 x i64]* %newstruct.i to [1 x i64] addrspace(11)*, !dbg !312
  call fastcc void @julia_throw_boundserror_11911({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %95, [1 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %101) #22, !dbg !312
  unreachable, !dbg !312

L42.i:                                            ; preds = %invertL42
  %getfield10.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr4.i unordered, align 8, !dbg !313, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %unbox11.i = load i64, i64 addrspace(11)* %96, align 8, !dbg !315, !tbaa !22, !alias.scope !30, !noalias !33
  %102 = addrspacecast {} addrspace(10)* %getfield10.i to double addrspace(13)* addrspace(11)*, !dbg !317
  %arrayptr56.i = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %102, align 16, !dbg !317, !tbaa !22, !alias.scope !319, !noalias !33, !nonnull !17
  %103 = getelementptr inbounds double, double addrspace(13)* %arrayptr56.i, i64 %unbox11.i, !dbg !317
  %arrayref.i = load double, double addrspace(13)* %103, align 8, !dbg !317, !tbaa !138, !alias.scope !141, !noalias !142
  %memcpy_refined_dst13.i = getelementptr inbounds [1 x i64], [1 x i64]* %newstruct12.i, i64 0, i64 0, !dbg !296
  store i64 1, i64* %memcpy_refined_dst13.i, align 8, !dbg !296, !tbaa !75, !alias.scope !77, !noalias !301
  %bitcast14.i = load i64, i64 addrspace(11)* %92, align 8, !dbg !302, !tbaa !22, !alias.scope !30, !noalias !33
  %.not57.i = icmp eq i64 %bitcast14.i, 0, !dbg !310
  br i1 %.not57.i, label %L60.i, label %julia_reverse_11904.exit, !dbg !312

L60.i:                                            ; preds = %L42.i
  %104 = addrspacecast [1 x i64]* %newstruct12.i to [1 x i64] addrspace(11)*, !dbg !312
  call fastcc void @julia_throw_boundserror_11911({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(80) %88, [1 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %104) #22, !dbg !312
  unreachable, !dbg !312

julia_reverse_11904.exit:                         ; preds = %L42.i
  %getfield16.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr.i unordered, align 8, !dbg !313, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %105 = addrspacecast {} addrspace(10)* %getfield16.i to double addrspace(13)* addrspace(11)*, !dbg !317
  %arrayptr1958.i = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %105, align 16, !dbg !317, !tbaa !22, !alias.scope !319, !noalias !33, !nonnull !17
  %unbox17.i = load i64, i64 addrspace(11)* %90, align 8, !dbg !315, !tbaa !22, !alias.scope !30, !noalias !33
  %106 = getelementptr inbounds double, double addrspace(13)* %arrayptr1958.i, i64 %unbox17.i, !dbg !317
  %arrayref20.i = load double, double addrspace(13)* %106, align 8, !dbg !317, !tbaa !138, !alias.scope !141, !noalias !142
  %107 = fmul double %arrayref20.i, 2.000000e+00, !dbg !320
  %108 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %79, i64 0, i64 0, !dbg !326
  %unbox21.i = load double, double addrspace(11)* %108, align 8, !dbg !327, !tbaa !22, !alias.scope !30, !noalias !33
  %109 = fmul double %107, %unbox21.i, !dbg !327
  %110 = fadd double %arrayref.i, %109, !dbg !328
  store double %110, double addrspace(13)* %103, align 8, !dbg !330, !tbaa !138, !alias.scope !141, !noalias !334
  %111 = call fastcc nonnull {} addrspace(10)* @julia_repr_11919({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %95), !dbg !335
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879058432336 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %111), !dbg !335
  %112 = call fastcc nonnull {} addrspace(10)* @julia_repr_11913([1 x double] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %79), !dbg !336
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 128 addrspacecast ({}* inttoptr (i64 125878243333248 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %112), !dbg !336
  %113 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 40, i8* %113), !dbg !337
  %114 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 40, i8* %114), !dbg !337
  %115 = bitcast [1 x i64]* %newstruct.i to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 8, i8* %115), !dbg !337
  %116 = bitcast [1 x i64]* %newstruct12.i to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 8, i8* %116), !dbg !337
  br label %invertL18

invertL67.loopexit:                               ; preds = %invertL67
  %_unwrap3 = add nsw i64 %arraysize, -1
  br label %mergeinvertL18_L67.loopexit

mergeinvertL18_L67.loopexit:                      ; preds = %invertL67.loopexit
  store i64 %_unwrap3, i64* %"iv'ac", align 8
  br label %invertL42

invertL67:                                        ; preds = %L67
  store double %differeturn, double* %"value_phi23'de", align 8
  %117 = load double, double* %"value_phi23'de", align 8
  store double 0.000000e+00, double* %"value_phi23'de", align 8
  %118 = xor i1 %.not, true
  %119 = select fast i1 %118, double %117, double 0.000000e+00
  %120 = load double, double* %"'de", align 8
  %121 = fadd fast double %120, %117
  %122 = select fast i1 %.not, double %120, double %121
  store double %122, double* %"'de", align 8
  br i1 %.not, label %inverttop, label %invertL67.loopexit
}

b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 0.0], 1)
b.dval = [8.0]
out = Active{Float64}(1.0)
b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 8.0], 1)
b.dval = [16.0]
out = Active{Float64}(1.0)
wsmoses commented 1 month ago

b is overwritten from fwd to reverse, as specified in the config.

Saving the original b fixes it.

EnzymeRules.overwritten(config) = (false, true)
EnzymeRules.overwritten(config) = (false, true)
b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 0.0], 1)
b.dval = [8.0]
out = Active{Float64}(1.0)
b.val = [1.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 1), 0, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 1), 0, [0.0 8.0], 1)
b.dval = [2.0]
out = Active{Float64}(1.0)
db = [2.0 8.0]
wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/main)) $ cat la.jl 
using Enzyme, LinearAlgebra
using EnzymeCore: EnzymeRules

Enzyme.API.printall!(true)

# Toy primal for the custom rule: the square of the first element of `b`.
_mul!(b) = b[1] * b[1]

# Forward (augmented) pass of the custom rule for `_mul!(b)`.
# Prints the overwritten flags from `config`, then returns the primal (only when the
# config asks for it), no shadow, and the incoming `Duplicated` itself as the tape.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Active},
    b::Duplicated,
)
    @show EnzymeRules.overwritten(config)
    primal = if EnzymeRules.needs_primal(config)
        _mul!(b.val)
    else
        nothing
    end
    # Taping `b` keeps the forward-pass view available to the reverse pass.
    return EnzymeRules.AugmentedReturn(primal, nothing, b)
end

# Reverse pass of the custom rule for `_mul!(b)`.
# Accumulates 2 * b[1] * out.val into the first shadow element and returns one
# `nothing` for the single `Duplicated` argument.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape,  b::Duplicated)
    # Deliberately shadow the `b` argument with the `Duplicated` saved on the tape
    # by augmented_primal; everything below reads the taped value.
    b = tape
    # Debug output: the view's value plus the fields of the primal and shadow views.
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    # d(b[1]^2)/db[1] = 2 * b[1], scaled by `out.val`
    # (presumably the incoming adjoint of the Active return -- confirm).
    b.dval[1] += 2 * b.val[1] * out.val
    @show b.dval
    @show out
    return (nothing,)
end

# Sum `_mul!` over the columns of `b`, handing each column to `_mul!` as a view.
@inline function f2(b)
    total = zero(eltype(b))
    ncols = size(b, 2)
    for col in 1:ncols
        total += _mul!(@view(b[:, col]))
    end
    return total
end

# Driver: differentiate `f2` over a 1x2 matrix in reverse mode.
# NOTE(review): Nx and Ny are assigned but not used in this script.
Nx = 1
Ny = 2
b = [1.0 4.0]
# Shadow buffer that receives the gradient with respect to `b`.
db = zero(b)

# Run the primal once, then the reverse-mode derivative (shadow re-zeroed first).
f2(b)
autodiff(Reverse, f2, Active, Duplicated(b, fill!(db, 0)))

@show db
# What Enzyme returns db = [0  16]
# What the correct answer is db = [2 8]
ptiede commented 1 month ago

Wait isn't the original code still a problem? This still gives me a zero in the first element

# Toy primal: product of the first element of `A` and the first element of `b`.
_mul!(A, b) = A[1] * b[1]

# Forward (augmented) pass for `_mul!(A, b)`: returns the primal when requested,
# no shadow, and tapes both annotated arguments as a tuple.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Active},
    A::Const{<:Matrix},
    b::Duplicated,
)
    primal = EnzymeRules.needs_primal(config) ? _mul!(A.val, b.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, (A, b))
end

# Reverse pass for `_mul!(A, b)`: adds real(A'[1] * out.val) to b.dval[1] and
# returns one `nothing` per argument.
# NOTE(review): `tape` is never unpacked here, so `A` and `b` below are the
# arguments handed to the reverse pass, not the values saved by augmented_primal.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const{<:Matrix}, b::Duplicated)

    # Debug output: overwritten flags, then the fields of the primal/shadow views.
    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += real(A.val'[1] * out.val)
    @show b.dval
    @show out
    @show A.val
    return (nothing, nothing)
end

# Loop over the columns of `b`, calling `_mul!` on a view of each column.
@inline function f2(A, b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        # NOTE(review): `out` is allocated with `similar` and never written, so the
        # `sum(abs, out)` below reads uninitialized memory.
        out = similar(A, size(A, 1))
        # NOTE(review): plain assignment -- this discards the value accumulated in
        # earlier iterations instead of adding to it.
        s = _mul!(A, @view(b[:, i]))
        @inbounds s += sum(abs, out)
    end
    return s
end

# Driver: 1x1 matrix `A` of ones and a 1x2 matrix `b`, differentiated w.r.t. `b`.
# NOTE(review): Ny is assigned but not used.
Nx = 1
Ny = 2
b = [1.0 4.0]
# Shadow buffer that receives the gradient with respect to `b`.
db = zero(b)

A = ones(Float64, Nx, Nx)

# Run the primal once, then the reverse-mode derivative (shadow re-zeroed first).
f2(A, b)
autodiff(Reverse, f2, Active, Const(A), Duplicated(b, fill!(db, 0)))
# Observed (incorrect) result: db = [0.0, 1.0] -- the first element stays zero.
wsmoses commented 1 month ago

There are some slight bugs in that (specifically, you don't unwrap the tape in the reverse pass, and you also write `s = _mul!(...)` where it should be `s += _mul!(...)`).

with those fixed, it works

using Enzyme, LinearAlgebra

# Toy primal: product of the first element of `A` and the first element of `b`.
_mul!(A, b) = A[1] * b[1]

# Forward (augmented) pass for `_mul!(A, b)`: returns the primal when requested,
# no shadow, and tapes both annotated arguments for the reverse pass.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Active},
    A::Const{<:Matrix},
    b::Duplicated,
)
    primal = if EnzymeRules.needs_primal(config)
        _mul!(A.val, b.val)
    else
        nothing
    end
    saved = (A, b)
    return EnzymeRules.AugmentedReturn(primal, nothing, saved)
end

# Reverse pass for `_mul!(A, b)`: adds real(A'[1] * out.val) to the first shadow
# element of the taped `b` and returns one `nothing` per argument.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const{<:Matrix}, b::Duplicated)
    # Unpack the tape so that `A` and `b` refer to the values saved by
    # augmented_primal instead of the arguments passed to the reverse pass.
    (A, b) = tape

    # Debug output: overwritten flags, then the fields of the primal/shadow views.
    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += real(A.val'[1] * out.val)
    @show b.dval
    @show out
    @show A.val
    return (nothing, nothing)
end

# Sum `_mul!(A, column)` over the columns of `b`.
#
# Arguments:
#   A -- passed through unchanged to `_mul!`
#   b -- matrix; each column is handed to `_mul!` as a view (no copy)
# Returns the accumulated sum, starting from `zero(eltype(b))`.
@inline function f2(A, b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        s += _mul!(A, @view(b[:, i]))
    end
    return s
end

# Driver: 1x1 matrix `A` of ones and a 1x2 matrix `b`, differentiated w.r.t. `b`.
# NOTE(review): Ny is assigned but not used.
Nx = 1
Ny = 2
b = [1.0 4.0]
# Shadow buffer that receives the gradient with respect to `b`.
db = zero(b)

A = ones(Float64, Nx, Nx)

# Run the primal once, then the reverse-mode derivative (shadow re-zeroed first).
f2(A, b)
autodiff(Reverse, f2, Active, Const(A), Duplicated(b, fill!(db, 0)))
# Expected gradient: db = [1.0 1.0], since each column contributes A[1] * b[1, i]
# and A[1] == 1.0. (`[0.0, 1.0]` was the output of the earlier, broken version.)
ptiede commented 1 month ago

For instance this still gives me something weird

# Toy primal: `A` times the first element of `b`.
_mul!(A, b) = A * b[1]

# Forward (augmented) pass for `_mul!(A, b)`: tapes copies of the two primal
# values, and returns the primal when requested and no shadow.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Active},
    A::Const,
    b::Duplicated,
)
    # The copies are taken before the primal is evaluated. They are fresh values
    # and do not alias `b.val` or `b.dval`.
    saved = (copy(A.val), copy(b.val))
    primal = EnzymeRules.needs_primal(config) ? _mul!(A.val, b.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, saved)
end

# Reverse pass for `_mul!(A, b)`: adds A * out.val to b.dval[1] and returns one
# `nothing` per argument.
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const, b::Duplicated)

    # NOTE(review): `b1` (the taped copy of b.val) is never used below.
    b1 = tape[2]
    # From here on `A` is the taped plain value, no longer the `Const` wrapper.
    A = tape[1]
    # NOTE(review): `b` is still the argument passed to the reverse pass, so the
    # update below writes into that argument's shadow, not into anything taped.
    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += (A * out.val)
    @show b.dval
    @show out
    @show A
    return (nothing, nothing)
end

# Sum `_mul!(A, column)` over the columns of `b`, passing each column as a view.
@inline function f2(A, b)
    acc = zero(eltype(b))
    ncols = size(b, 2)
    for col in 1:ncols
        term = _mul!(A, @view(b[:, col]))
        @inbounds acc += term
    end
    return acc
end

# Driver: scalar `A = 1.0` and a 1x2 matrix `b`, differentiated w.r.t. `b`.
# NOTE(review): Nx and Ny are assigned but not used in this script.
Nx = 1
Ny = 2
b = [1.0 4.0]
# Shadow buffer that receives the gradient with respect to `b`.
db = zero(b)

# Run the primal once, then the reverse-mode derivative (shadow re-zeroed first).
f2(1.0, b)
autodiff(Reverse, f2, Active, Const(1.0), Duplicated(b, fill!(db, 0)))
wsmoses commented 1 month ago

You can't use `copy` here: it creates a different copy of b, and thus it won't point to the same reference as b to be updated in place.

ptiede commented 1 month ago

Ok I need help understanding the overwritten stuff. So, in fwd pass I should check if b is overwritten. If it is, I need to store the forward pass b in the tape and use that in the reverse rule instead of the b passed to the reverse rule? In other rules e.g., here I see you copy the value in the forward pass. When should I copy and not?

wsmoses commented 1 month ago
using Enzyme, LinearAlgebra

# Problem setup shared by `f` and `f2` below.
# vl: number of output rows per block; tot: total length of the output vector.
vl = [2, 2]
tot = sum(vl)
# rg: the output index range that each block writes to.
rg = [1:2, 3:4]
Nx = 4

# iminds: keys into `Bs`, also used to index the third dimension of `x`.
iminds = reshape([CartesianIndex(i) for i in 1:2], :)
# visinds: the ranges of `rg` materialized as index vectors.
visinds = [collect(rg[i]) for i in eachindex(rg)]
# Bs: one complex matrix of ones per key, of size vl[i] x (Nx*Nx).
Bs = Dict((iminds[i]=> ones(ComplexF64, vl[i], Nx*Nx) for i in eachindex(vl)))

# In-place product: write `A * b` into `out` via `mul!` and return `nothing`.
_mul!(out, A, b) = (mul!(out, A, b); nothing)

# Forward (augmented) pass for the in-place `_mul!(out, A, b)`.
# Runs the primal product and tapes the `Duplicated` wrappers of `out` and `b`, so
# the reverse pass can use the forward-pass views rather than its own arguments.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const},
    out::Duplicated,
    A::Const{<:Matrix},
    b::Duplicated,
)
    _mul!(out.val, A.val, b.val)
    # No primal and no shadow are returned; only the tape is.
    return EnzymeRules.AugmentedReturn(nothing, nothing, (out, b))
end

# Reverse pass for the in-place `_mul!(out, A, b)`.
# Adds real(A' * d_out) to the shadow of the taped `b`, then zeroes the shadow of
# the taped `out`. Returns one `nothing` per argument.
@noinline function EnzymeRules.reverse(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const},
    tape,
    out::Duplicated,
    A::Const{<:Matrix},
    b::Duplicated,
)
    # Rebind to the wrappers saved by augmented_primal; the arguments of the same
    # name are intentionally not used after this line.
    out, b = tape

    pullback = real.(A.val' * out.dval)
    b.dval .+= pullback
    out.dval .= 0
    return (nothing, nothing, nothing)
end

# For every (key, row-set) pair, write Bs[key] * vec(x[:, :, key]) into the matching
# rows of a complex output vector using `_mul!`, then return the sum of squared
# magnitudes of that vector.
@inline function f(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for k in eachindex(iminds, visinds)
        key = iminds[k]
        rows = visinds[k]
        dest = @view out[rows]
        mat = Bs[key]
        src = reshape(@view(x[:, :, key]), :)
        _mul!(dest, mat, src)
    end
    return sum(abs2, out)
end

# Same computation as `f`, but calling LinearAlgebra's `mul!` directly rather than
# the `_mul!` wrapper that carries the custom rule.
@inline function f2(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for k in eachindex(iminds, visinds)
        key = iminds[k]
        rows = visinds[k]
        dest = @view out[rows]
        mat = Bs[key]
        src = reshape(@view(x[:, :, key]), :)
        mul!(dest, mat, src)
    end
    return sum(abs2, out)
end

# Differentiate `f` (custom-rule path) with respect to `x`.
x = ones(Nx, Nx, 2)
# Shadow buffer that receives the gradient with respect to `x`.
dx = zero(x)

f(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
# NOTE(review): by hand, every entry of dx should be 64.0 for this input
# (2 outputs per slice, each contributing 2 * 16) -- confirm against a run.
@show dx

# Repeat with `f2` (plain `mul!`, no custom rule) as a reference on fresh buffers.
x = ones(Nx, Nx, 2)
dx = zero(x)

f2(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx

This version of your original code works successfully^

ptiede commented 1 month ago

Thanks! I think I am confused about when to store a copy in the tape and when to store the overwritten arguments in augmented_primal.

wsmoses commented 1 month ago

yeah okay the more I think about this, we should probably force enzyme to automatically do the nice thing here for top level pointers.

Nevertheless, the above should be a good workaround.

ptiede commented 1 month ago

Ok no problem! I'll implement the fix in my code. Thanks for all the help.