EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
458 stars 65 forks source link

Enzyme doesn't work for `AdvancedVI` Part V: Type stable restructure fails #1638

Closed Red-Portal closed 4 months ago

Red-Portal commented 4 months ago

Hi, here we go again. I tried to resolve the previous type instability issue by forcing the return type. This, however, results in a build error. I have a nice MWE for this one.


using Enzyme, Optimisers, Functors, Distributions, LinearAlgebra, SimpleUnPack

struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
    scale_eps::E
end

Base.length(q::MvLocationScale) = length(q.location)

@functor MvLocationScale (location, scale)

struct RestructureMeanField{S <: Diagonal, D, L}
    model::MvLocationScale{S, D, L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims   = div(length(flat), 2)
    location = first(flat, n_dims)
    scale    = Diagonal(last(flat, n_dims))
    MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
end

function Optimisers.destructure(
    q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
    @unpack location, scale, dist = q
    flat = vcat(location, diag(scale))
    flat, RestructureMeanField(q)
end

restructure_ad_forward(re, params) = re(params)::typeof(re.model)

function f(params, aux)
    @unpack restructure = aux
    q = restructure_ad_forward(restructure, params)
    sum(q.location)
end

function main()
    d = 10
    m = zeros(d)
    C = Diagonal(ones(d))
    q = MvLocationScale(m, C, Normal(), 1e-5)

    params, re = Optimisers.destructure(q)

    aux = (
        restructure = re,
    )

    display(f(params, aux))

    x = ones(length(params))
    ∇x = zeros(length(params))
    Enzyme.API.runtimeActivity!(true)
    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, ∇x), Enzyme.Const(aux))
    ∇x
end

The error is as follows:

ERROR: Enzyme execution failed.
Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140112659606240" "enzymejl_parmtype_ref"="1" double @preprocess_julia_f_4431_inner.1({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="140112588571488" "enzymejl_parmtype_ref"="2" %0, [1 x [1 x {} addrspace(10)*]] "enzyme_type"="{}" "enzymejl_parmtype"="140110628297552" "enzymejl_parmtype_ref"="1" %1) local_unnamed_addr #22 !dbg !477 {
entry:
  %2 = alloca [1 x [1 x {} addrspace(10)*]], align 8, !dbg !478, !enzyme_inactive !16, !enzyme_type !203
  %.fca.0.0.extract = extractvalue [1 x [1 x {} addrspace(10)*]] %1, 0, 0, !dbg !478
  %.fca.0.0.gep = getelementptr inbounds [1 x [1 x {} addrspace(10)*]], [1 x [1 x {} addrspace(10)*]]* %2, i64 0, i64 0, i64 0, !dbg !478
  store {} addrspace(10)* %.fca.0.0.extract, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !478, !noalias !479
  %3 = call {}*** @julia.get_pgcstack() #23
  %ptls_field.i3 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field.i3 to i64***
  %ptls_load.i45 = load i64**, i64*** %4, align 8, !tbaa !17
  %5 = getelementptr inbounds i64*, i64** %ptls_load.i45, i64 2
  %safepoint.i = load i64*, i64** %5, align 8, !tbaa !21
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #23, !dbg !482
  fence syncscope("singlethread") seq_cst
  %6 = getelementptr inbounds [1 x [1 x {} addrspace(10)*]], [1 x [1 x {} addrspace(10)*]]* %2, i64 0, i64 0, !dbg !484
  %7 = addrspacecast [1 x {} addrspace(10)*]* %6 to [1 x {} addrspace(10)*] addrspace(11)*, !dbg !484
  %8 = call fastcc nonnull {} addrspace(10)* @julia_RestructureMeanField_4440([1 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %7, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %0) #23, !dbg !486
  %getfield.i = load atomic {} addrspace(10)*, {} addrspace(10)** %.fca.0.0.gep unordered, align 8, !dbg !488, !alias.scope !193, !noalias !194, !nonnull !16
  %typeof.i = call "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.typeof({} addrspace(10)* nonnull %getfield.i) #24, !dbg !486
  call void @ijl_typeassert({} addrspace(10)* nonnull %8, {} addrspace(10)* nonnull %typeof.i) #23, !dbg !486
  %9 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @jl_f_getfield, {} addrspace(10)* noundef null, {} addrspace(10)* nonnull %8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140111116898992 to {}*) to {} addrspace(10)*)) #25, !dbg !490
  %10 = addrspacecast {} addrspace(10)* %9 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !492
  %arraylen_ptr.i = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %10, i64 0, i32 1, !dbg !492
  %arraylen.i = load i64, i64 addrspace(11)* %arraylen_ptr.i, align 8, !dbg !492, !tbaa !54, !range !57, !alias.scope !58, !noalias !59
  switch i64 %arraylen.i, label %L19.i [
    i64 0, label %julia_f_4431_inner.exit
    i64 1, label %L17.i
  ], !dbg !505

L17.i:                                            ; preds = %entry
  %11 = addrspacecast {} addrspace(10)* %9 to double addrspace(13)* addrspace(11)*, !dbg !506
  %arrayptr.i7 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %11, align 8, !dbg !506, !tbaa !251, !alias.scope !508, !noalias !59, !nonnull !16
  %arrayref.i = load double, double addrspace(13)* %arrayptr.i7, align 8, !dbg !506, !tbaa !254, !alias.scope !32, !noalias !256
  br label %julia_f_4431_inner.exit, !dbg !509

L19.i:                                            ; preds = %entry
  %12 = icmp ugt i64 %arraylen.i, 15, !dbg !510
  br i1 %12, label %L35.i, label %L21.i, !dbg !512

L21.i:                                            ; preds = %L19.i
  %13 = addrspacecast {} addrspace(10)* %9 to double addrspace(13)* addrspace(11)*, !dbg !513
  %arrayptr3.i8 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %13, align 8, !dbg !513, !tbaa !251, !alias.scope !508, !noalias !59, !nonnull !16
  %arrayref4.i = load double, double addrspace(13)* %arrayptr3.i8, align 8, !dbg !513, !tbaa !254, !alias.scope !32, !noalias !256
  %14 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 1, !dbg !515
  %arrayref7.i = load double, double addrspace(13)* %14, align 8, !dbg !515, !tbaa !254, !alias.scope !32, !noalias !256
  %15 = fadd double %arrayref4.i, %arrayref7.i, !dbg !517
  %.not910 = icmp ugt i64 %arraylen.i, 2, !dbg !520
  br i1 %.not910, label %L30.i.lr.ph, label %julia_f_4431_inner.exit, !dbg !522

L30.i.lr.ph:                                      ; preds = %L21.i
  %16 = add nsw i64 %arraylen.i, -2, !dbg !522
  %17 = add nsw i64 %arraylen.i, -3, !dbg !522
  %xtraiter = and i64 %16, 7, !dbg !522
  %18 = icmp ult i64 %17, 7, !dbg !522
  br i1 %18, label %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa, label %L30.i.lr.ph.new, !dbg !522

L30.i.lr.ph.new:                                  ; preds = %L30.i.lr.ph
  %unroll_iter = and i64 %16, -8, !dbg !522
  br label %L30.i, !dbg !522

L30.i:                                            ; preds = %L30.i, %L30.i.lr.ph.new
  %iv = phi i64 [ %iv.next, %L30.i ], [ 0, %L30.i.lr.ph.new ]
  %value_phi8.i11 = phi double [ %15, %L30.i.lr.ph.new ], [ %45, %L30.i ]
  %19 = shl nuw i64 %iv, 3, !dbg !523
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !523
  %20 = shl i64 %iv, 3, !dbg !523
  %21 = add nuw nsw i64 %20, 2, !dbg !523
  %22 = or i64 %21, 1, !dbg !523
  %23 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %21, !dbg !525
  %arrayref13.i = load double, double addrspace(13)* %23, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %24 = fadd double %value_phi8.i11, %arrayref13.i, !dbg !526
  %25 = add nuw nsw i64 %21, 2, !dbg !523
  %26 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %22, !dbg !525
  %arrayref13.i.1 = load double, double addrspace(13)* %26, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %27 = fadd double %24, %arrayref13.i.1, !dbg !526
  %28 = add nuw nsw i64 %21, 3, !dbg !523
  %29 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %25, !dbg !525
  %arrayref13.i.2 = load double, double addrspace(13)* %29, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %30 = fadd double %27, %arrayref13.i.2, !dbg !526
  %31 = add nuw nsw i64 %21, 4, !dbg !523
  %32 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %28, !dbg !525
  %arrayref13.i.3 = load double, double addrspace(13)* %32, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %33 = fadd double %30, %arrayref13.i.3, !dbg !526
  %34 = add nuw nsw i64 %21, 5, !dbg !523
  %35 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %31, !dbg !525
  %arrayref13.i.4 = load double, double addrspace(13)* %35, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %36 = fadd double %33, %arrayref13.i.4, !dbg !526
  %37 = add nuw nsw i64 %21, 6, !dbg !523
  %38 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %34, !dbg !525
  %arrayref13.i.5 = load double, double addrspace(13)* %38, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %39 = fadd double %36, %arrayref13.i.5, !dbg !526
  %40 = add nuw nsw i64 %21, 7, !dbg !523
  %41 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %37, !dbg !525
  %arrayref13.i.6 = load double, double addrspace(13)* %41, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %42 = fadd double %39, %arrayref13.i.6, !dbg !526
  %43 = add nuw nsw i64 %21, 8, !dbg !523
  %44 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %40, !dbg !525
  %arrayref13.i.7 = load double, double addrspace(13)* %44, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %45 = fadd double %42, %arrayref13.i.7, !dbg !526
  %niter.next.7 = add i64 %19, 8, !dbg !522
  %niter.ncmp.7.not = icmp eq i64 %niter.next.7, %unroll_iter, !dbg !522
  br i1 %niter.ncmp.7.not, label %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit, label %L30.i, !dbg !522

L35.i:                                            ; preds = %L19.i
  %46 = call fastcc double @julia_mapreduce_impl_4435({} addrspace(10)* nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %9, i64 noundef signext 1, i64 signext %arraylen.i) #23, !dbg !529
  br label %julia_f_4431_inner.exit, !dbg !531

L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit: ; preds = %L30.i
  br label %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa, !dbg !522

L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa: ; preds = %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit, %L30.i.lr.ph
  %.lcssa.ph = phi double [ undef, %L30.i.lr.ph ], [ %45, %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit ]
  %value_phi9.i12.unr = phi i64 [ 2, %L30.i.lr.ph ], [ %43, %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit ]
  %value_phi8.i11.unr = phi double [ %15, %L30.i.lr.ph ], [ %45, %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa.loopexit ]
  %lcmp.mod.not = icmp eq i64 %xtraiter, 0, !dbg !522
  br i1 %lcmp.mod.not, label %julia_f_4431_inner.exit, label %L30.i.epil.preheader, !dbg !522

L30.i.epil.preheader:                             ; preds = %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa
  br label %L30.i.epil, !dbg !522

L30.i.epil:                                       ; preds = %L30.i.epil.preheader, %L30.i.epil
  %iv1 = phi i64 [ 0, %L30.i.epil.preheader ], [ %iv.next2, %L30.i.epil ]
  %value_phi8.i11.epil = phi double [ %50, %L30.i.epil ], [ %value_phi8.i11.unr, %L30.i.epil.preheader ]
  %47 = add nuw nsw i64 %value_phi9.i12.unr, %iv1, !dbg !523
  %iv.next2 = add nuw nsw i64 %iv1, 1, !dbg !523
  %48 = add nuw nsw i64 %47, 1, !dbg !523
  %49 = getelementptr inbounds double, double addrspace(13)* %arrayptr3.i8, i64 %47, !dbg !525
  %arrayref13.i.epil = load double, double addrspace(13)* %49, align 8, !dbg !525, !tbaa !254, !alias.scope !32, !noalias !256
  %50 = fadd double %value_phi8.i11.epil, %arrayref13.i.epil, !dbg !526
  %epil.iter.cmp.not = icmp eq i64 %iv.next2, %xtraiter, !dbg !522
  br i1 %epil.iter.cmp.not, label %julia_f_4431_inner.exit.loopexit, label %L30.i.epil, !dbg !522, !llvm.loop !532

julia_f_4431_inner.exit.loopexit:                 ; preds = %L30.i.epil
  br label %julia_f_4431_inner.exit, !dbg !478

julia_f_4431_inner.exit:                          ; preds = %julia_f_4431_inner.exit.loopexit, %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa, %L35.i, %L21.i, %L17.i, %entry
  %value_phi.i = phi double [ %arrayref.i, %L17.i ], [ %46, %L35.i ], [ 0.000000e+00, %entry ], [ %15, %L21.i ], [ %.lcssa.ph, %L24.i.julia_f_4431_inner.exit.loopexit_crit_edge.unr-lcssa ], [ %50, %julia_f_4431_inner.exit.loopexit ]
  ret double %value_phi.i, !dbg !478
}

No augmented forward pass found for ijl_typeassert
 at context:   call void @ijl_typeassert({} addrspace(10)* nonnull %8, {} addrspace(10)* nonnull %typeof.i) #23, !dbg !35

Stacktrace:
 [1] restructure_ad_forward
   @ ./REPL[43]:1
 [2] f
   @ ./REPL[44]:3
 [3] f
   @ ./REPL[44]:0

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/qd8AI/src/compiler.jl:1620
  [2] restructure_ad_forward
    @ ./REPL[43]:1 [inlined]
  [3] f
    @ ./REPL[44]:3 [inlined]
  [4] f
    @ ./REPL[44]:0 [inlined]
  [5] diffejulia_f_4431_inner_1wrap
    @ ./REPL[44]:0
  [6] macro expansion
    @ ~/.julia/packages/Enzyme/qd8AI/src/compiler.jl:6606 [inlined]
  [7] enzyme_call
    @ ~/.julia/packages/Enzyme/qd8AI/src/compiler.jl:6207 [inlined]
  [8] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/qd8AI/src/compiler.jl:6084 [inlined]
  [9] autodiff
    @ ~/.julia/packages/Enzyme/qd8AI/src/Enzyme.jl:309 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/qd8AI/src/Enzyme.jl:321 [inlined]
 [11] main()
    @ Main ./REPL[38]:18
 [12] top-level scope
    @ REPL[45]:1
wsmoses commented 4 months ago

Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1639 please reopen if it persists