Enzyme compilation failed for cmpxchg #655

Closed jeremiedb closed 10 months ago

jeremiedb commented 1 year ago

The following Julia code attempts to differentiate a convolution from NNlib but fails:

using Enzyme
using NNlib

w = randn(Float32, 3, 3, 5, 7);
dw = zero(w)
loss(w, x) = sum(conv(x, w))
x = randn(Float32, (3, 3, 5, 8));
grads = Enzyme.autodiff(loss, Duplicated(w, dw), Const(x));

This results in the following error trace (truncated for readability):

ERROR: Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef i8 @preprocess_julia__trylock_7199({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #119 !dbg !6737 {
  %2 = call {}*** @julia.get_pgcstack() #110
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6738


in Mode: ReverseModePrimal
cannot handle unknown instruction
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !142, !tbaa !146
vtjnash commented 1 year ago

Should trylock also have a rule (since that appears to be where this comes from)?

wsmoses commented 1 year ago

It's likely reasonable to mark it inactive, but I'd want to see what's calling trylock, since it might make more sense to have the outer code that calls trylock be inactive.

@jeremiedb can you see where it's being called?

jeremiedb commented 1 year ago

The complete error trace for the above conv call is:

julia> grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));
ERROR: Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef i8 @preprocess_julia__trylock_11344({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #127 !dbg !6627 {
  %2 = call {}*** @julia.get_pgcstack() #128
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6628
  %3 = bitcast {}*** %ptls_field8 to i32**, !dbg !6628
  %ptls_load910 = load i32*, i32** %3, align 8, !dbg !6628, !tbaa !2355
  %4 = getelementptr inbounds i32, i32* %ptls_load910, i64 8, !dbg !6628
  %5 = load i32, i32* %4, align 4, !dbg !6628
  %6 = add i32 %5, 1, !dbg !6628
  store i32 %6, i32* %4, align 4, !dbg !6628
  %7 = bitcast {} addrspace(10)* %0 to i8 addrspace(10)*, !dbg !6630
  %8 = addrspacecast i8 addrspace(10)* %7 to i8 addrspace(11)*, !dbg !6630
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 12, !dbg !6630
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !6630, !tbaa !266
  %11 = extractvalue { i8, i1 } %10, 1, !dbg !6630
  br i1 %11, label %L6, label %L9, !dbg !6631

common.ret:                                       ; preds = %L16, %L9, %L6
  %common.ret.op = phi i8 [ 1, %L6 ], [ 0, %L9 ], [ 0, %L16 ]
  ret i8 %common.ret.op, !dbg !6632

L6:                                               ; preds = %top
  %12 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !6633
  %13 = bitcast i8 addrspace(11)* %12 to i32 addrspace(11)*, !dbg !6633
  store i32 1, i32 addrspace(11)* %13, align 8, !dbg !6633, !tbaa !266
  %14 = bitcast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(10)*, !dbg !6635
  store atomic {} addrspace(10)* %1, {} addrspace(10)* addrspace(10)* %14 release, align 8, !dbg !6635, !tbaa !266
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* nofree noundef nonnull %0, {} addrspace(10)* nofree nonnull %1) #129, !dbg !6635
  br label %common.ret

L9:                                               ; preds = %top
  %ptls_load41314 = load i32*, i32** %3, align 8, !dbg !6637, !tbaa !2355
  %15 = getelementptr inbounds i32, i32* %ptls_load41314, i64 8, !dbg !6637
  %16 = load i32, i32* %15, align 4, !dbg !6637
  %17 = add i32 %16, -1, !dbg !6637
  %18 = icmp eq i32 %16, 0, !dbg !6637
  %19 = select i1 %18, i32 0, i32 %17, !dbg !6637
  store i32 %19, i32* %15, align 4, !dbg !6637
  %20 = load atomic i32, i32* inttoptr (i64 140730518437480 to i32*) monotonic, align 8, !dbg !6639, !tbaa !600
  %.not = icmp eq i32 %20, 0, !dbg !6640
  br i1 %.not, label %common.ret, label %L16, !dbg !6639

L16:                                              ; preds = %L9
  call void @jl_gc_run_pending_finalizers(i64 noundef 0) #128, !dbg !6643
  br label %common.ret, !dbg !6643

; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef { {} addrspace(10)*, i8 } @fakeaugmented_julia__trylock_11344({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* %"'", {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #127 !dbg !6644 {
  %2 = call {}*** @julia.get_pgcstack() #128
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6645
  %3 = bitcast {}*** %ptls_field8 to i32**, !dbg !6645
  %ptls_load910 = load i32*, i32** %3, align 8, !dbg !6645, !tbaa !2355
  %"ptls_load910'il_phi" = phi i32* , !dbg !6645
  %4 = getelementptr inbounds i32, i32* %ptls_load910, i64 8, !dbg !6645
  %5 = load i32, i32* %4, align 4, !dbg !6645
  %"'il_phi" = phi i32 , !dbg !6645
  %6 = add i32 %5, 1, !dbg !6645
  store i32 %6, i32* %4, align 4, !dbg !6645
  %7 = bitcast {} addrspace(10)* %0 to i8 addrspace(10)*, !dbg !6647
  %8 = addrspacecast i8 addrspace(10)* %7 to i8 addrspace(11)*, !dbg !6647
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 12, !dbg !6647
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !6647, !tbaa !266
  %11 = extractvalue { i8, i1 } %10, 1, !dbg !6647
  br i1 %11, label %L6, label %L9, !dbg !6648

common.ret:                                       ; preds = %L16, %L9, %L6
  %common.ret.op = phi i8 [ 1, %L6 ], [ 0, %L9 ], [ 0, %L16 ]
  %12 = insertvalue { {} addrspace(10)*, i8 } undef, i8 %common.ret.op, 1, !dbg !6649
  ret { {} addrspace(10)*, i8 } %12, !dbg !6649

L6:                                               ; preds = %top
  %13 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !6650
  %14 = bitcast i8 addrspace(11)* %13 to i32 addrspace(11)*, !dbg !6650
  store i32 1, i32 addrspace(11)* %14, align 8, !dbg !6650, !tbaa !266
  %15 = bitcast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(10)*, !dbg !6652
  store atomic {} addrspace(10)* %1, {} addrspace(10)* addrspace(10)* %15 release, align 8, !dbg !6652, !tbaa !266
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* nofree noundef nonnull %0, {} addrspace(10)* nofree nonnull %1) #129, !dbg !6652
  br label %common.ret

L9:                                               ; preds = %top
  %ptls_load41314 = load i32*, i32** %3, align 8, !dbg !6654, !tbaa !2355
  %"ptls_load41314'il_phi" = phi i32* , !dbg !6654
  %16 = getelementptr inbounds i32, i32* %ptls_load41314, i64 8, !dbg !6654
  %17 = load i32, i32* %16, align 4, !dbg !6654
  %"'il_phi1" = phi i32 , !dbg !6654
  %18 = add i32 %17, -1, !dbg !6654
  %19 = icmp eq i32 %17, 0, !dbg !6654
  %20 = select i1 %19, i32 0, i32 %18, !dbg !6654
  store i32 %20, i32* %16, align 4, !dbg !6654
  %21 = load atomic i32, i32* inttoptr (i64 140730518437480 to i32*) monotonic, align 8, !dbg !6656, !tbaa !600
  %"'il_phi2" = phi i32 , !dbg !6657
  %.not = icmp eq i32 %21, 0, !dbg !6657
  br i1 %.not, label %common.ret, label %L16, !dbg !6656

L16:                                              ; preds = %L9
  call void @jl_gc_run_pending_finalizers(i64 noundef 0) #128, !dbg !6660
  br label %common.ret, !dbg !6660

allocsForInversion:                               ; No predecessors!

in Mode: ReverseModePrimal
cannot handle unknown instruction
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !143, !tbaa !147

 [1] replaceproperty!
   @ .\Base.jl:58
 [2] _trylock
   @ .\lock.jl:82

  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:4735
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Duplicated{Array{Float32, 4}}, Const{Array{Float32, 4}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:6195
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7446
  [5] _thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7958 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7952
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7996
  [8] #s451#163
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8056 [inlined]
  [9] var"#s451#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ShadowInit::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:582
 [11] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8089 [inlined]
 [12] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8082 [inlined]
 [13] autodiff
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\Enzyme.jl:197 [inlined]
 [14] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Duplicated{Array{Float32, 4}}, ::Const{Array{Float32, 4}})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\Enzyme.jl:223
 [15] top-level scope
    @ c:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv.jl:18

And LLVM is:

julia> @code_llvm conv(x, w)
;  @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:50 within `conv`
; Function Attrs: uwtable
define nonnull {}* @julia_conv_12575({}* nonnull align 16 dereferenceable(40) %0, {}* nonnull align 16 dereferenceable(40) %1) #0 {
  %2 = alloca <4 x i64>, align 8
  %tmpcast = bitcast <4 x i64>* %2 to [4 x i64]*
  %3 = alloca <4 x i64>, align 8
  %tmpcast9 = bitcast <4 x i64>* %3 to [4 x i64]*
  %4 = alloca { [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }, align 8
; ┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:54 within `#conv#231`
; │┌ @ array.jl:153 within `size`
; ││┌ @ ntuple.jl:69 within `ntuple`
; │││┌ @ ntuple.jl:74 within `macro expansion`
; ││││┌ @ array.jl:153 within `#108`
; │││││┌ @ array.jl:150 within `size`
        %5 = bitcast {}* %0 to {}**
        %6 = getelementptr inbounds {}*, {}** %5, i64 3
        %7 = bitcast {}** %6 to <4 x i64>*
        %8 = load <4 x i64>, <4 x i64>* %7, align 8
; ││││└└
; ││││ @ ntuple.jl:75 within `macro expansion`
      store <4 x i64> %8, <4 x i64>* %2, align 8
; ││││ @ ntuple.jl:74 within `macro expansion`
; ││││┌ @ array.jl:153 within `#108`
; │││││┌ @ array.jl:150 within `size`
        %9 = bitcast {}* %1 to {}**
        %10 = getelementptr inbounds {}*, {}** %9, i64 3
        %11 = bitcast {}** %10 to <4 x i64>*
        %12 = load <4 x i64>, <4 x i64>* %11, align 8
; ││││└└
; ││││ @ ntuple.jl:75 within `macro expansion`
      store <4 x i64> %12, <4 x i64>* %3, align 8
; │└└└
; │┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\dim_helpers\DenseConvDims.jl:20 within `Type##kw`
    call void @"j_#DenseConvDims#8_12577"({ [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }* noalias nocapture nonnull sret({ [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }) %4, [2 x i64]* nocapture readonly @_j_const1, [2 x i64]* nocapture readonly @_j_const2, [2 x i64]* nocapture readonly @_j_const1, i64 signext 1, i8 zeroext 0, {}* readonly inttoptr (i64 2862949472816 to {}*), [4 x i64]* nocapture readonly %tmpcast, [4 x i64]* nocapture readonly %tmpcast9) #0
; │└
; │ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:56 within `#conv#231`
; │┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:83 within `conv`
    %13 = call nonnull {}* @"j_#conv#233_12578"({}* nonnull %0, {}* nonnull %1, { [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }* nocapture readonly %4) #0
; └└
  ret {}* %13

I've been looking to break down the NNlib.conv calls in https://github.com/jeremiedb/ADTests.jl/blob/main/experiments/enzyme/conv-debug.jl

I'm still unclear which instruction is the source of the issue. Two potential candidates could be:

wsmoses commented 1 year ago

I don't think it's the GC.@preserve, but the Threads.@sync might be it? @vtjnash @vchuravy?

vchuravy commented 1 year ago
mutable struct Atomic{T}
    @atomic x::T

function f(x, y)
    @atomic x.x max y
    val = @atomic x.x

@show f(Atomic(0.0), 1.0)

using Enzyme

x = Atomic(2.0)
dx = Atomic(0.0)

autodiff(Reverse, f, ACtive, Duplicated(x, dx), y)
jeremiedb commented 1 year ago

For info, the above f on Atomic results in a different error, which segfaults:

PS C:\Users\jerem\OneDrive\github\ADTests.jl> julia --project=@. --threads=1 .\experiments\enzyme\conv-debug.jl
f(Atomic(0.0), 3.0) = 9.0
module: ; ModuleID = 'text'
source_filename = "text"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-w64-mingw32"

; Function Attrs: nofree nosync readnone uwtable
define internal fastcc double @julia_max_1932(double %0, double %1) unnamed_addr #0 !dbg !7 {
  %2 = call {}*** @julia.get_pgcstack()
  %3 = fcmp olt double %0, %1, !dbg !9
  %4 = bitcast double %1 to i64, !dbg !16
  %5 = bitcast double %0 to i64, !dbg !16
  %.not = icmp sgt i64 %4, -1, !dbg !19
  %6 = icmp slt i64 %5, 0, !dbg !24
  %7 = and i1 %6, %.not, !dbg !24
  %8 = or i1 %3, %7, !dbg !26
  %9 = fcmp ord double %0, 0.000000e+00, !dbg !28
  %10 = select i1 %9, double %1, double %0, !dbg !32
  %11 = fcmp ord double %1, 0.000000e+00, !dbg !28
  %12 = select i1 %11, double %0, double %1, !dbg !32
  %13 = select i1 %8, double %10, double %12, !dbg !32
  ret double %13, !dbg !15
jeremiedb commented 1 year ago

Trying to further isolate the issue with NNlib.conv, the call to NNlib.gemm! seems to be problematic, although it results in a different error message. As gemm! takes Val arguments as inputs, could it be related to #654?

using Enzyme
using NNlib

function my_gemm!(y, x, w)
    x_ptr = pointer(x)
    w_ptr = pointer(w)
    y_ptr = pointer(y)
        size(x, 1),
        size(w, 2),
        size(x, 2),
    return y

x = rand(2, 3)
w = rand(3, 5)
y = zeros(2, 5)

dx = zeros(2, 3)
dw = zeros(3, 5)
dy = zeros(2, 5)

my_gemm!(y, x, w)
loss(y, x, w) = sum(my_gemm!(y, x, w))
loss(y, x, w)

autodiff(Reverse, loss, Duplicated(y, dy), Const(x), Duplicated(w, dw))
!509 = !DILocation(line: 150, scope: !30, inlinedAt: !510)
!510 = !DILocation(line: 14, scope: !499)
!511 = !DILocation(line: 8, scope: !35, inlinedAt: !512)
!512 = !DILocation(line: 104, scope: !38, inlinedAt: !513)
!513 = !DILocation(line: 412, scope: !41, inlinedAt: !514)
!514 = !DILocation(line: 48, scope: !44, inlinedAt: !510)
!515 = !DILocation(line: 26, scope: !499)

No augmented forward pass found for .text
declare void @.text(i8*, i8*, i8*, i8*, i8*, i8*, i64, i8*, i64, i8*, i8*, i64, i8*) local_unnamed_addr #6

  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:4735
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Duplicated{Matrix{Float64}}, Const{Matrix{Float64}}, Duplicated{Matrix{Float64}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::NTuple{4, Bool}, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:6195
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7446
  [5] _thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7958 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7952
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7996
  [8] #s451#163
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8056 [inlined]
  [9] var"#s451#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ShadowInit::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:582
 [11] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8089 [inlined]
 [12] thunk(f::typeof(loss), df::Nothing, ::Type{Active{Float64}}, tt::Type{Tuple{Duplicated{Matrix{Float64}}, Const{Matrix{Float64}}, Duplicated{Matrix{Float64}}}}, ::Val{Enzyme.API.DEM_ReverseModeCombined}, ::Val{1}, ::Val{(false, false, false, false)}, ::Val{false})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8082
 [13] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Type{Active{Float64}}, ::Duplicated{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\Enzyme.jl:197
 [14] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Duplicated{Matrix{Float64}}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\Enzyme.jl:223
 [15] top-level scope
    @ C:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:41
in expression starting at C:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:41
jakubMitura14 commented 10 months ago

Is it now possible to autodifferentiate convolutions in Enzyme?

wsmoses commented 10 months ago

It should, yes.

jakubMitura14 commented 10 months ago

fantastic, thanks !

jeremiedb commented 10 months ago

@jakubMitura14 Did you manage to differentiate a convolution? On CPU, it still errors on my end on Enzyme#main. When calling the above my_gemm!, which is the core operator called within NNlib's conv, it now results in the following stacktrace:

julia> autodiff(Reverse, loss, Duplicated(y, dy), Const(x), Duplicated(w, dw))
ERROR: Enzyme execution failed.
Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress uwtable willreturn
define internal fastcc void @preprocess_julia_my_gemm__12471({} addrspace(10)* nonnull align 16 dereferenceable(40) %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) unnamed_addr #12 !dbg !477 {
  %3 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !478
  %4 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !478
  %5 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %6 = bitcast i8* %5 to i64*, !enzyme_caststack !4
  %7 = bitcast i64* %6 to i8*
  %8 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %9 = bitcast i8* %8 to i64*, !enzyme_caststack !4
  %10 = bitcast i64* %9 to i8*
  %11 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %12 = bitcast i8* %11 to i64*, !enzyme_caststack !4
  %13 = bitcast i64* %12 to i8*
  %14 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %15 = bitcast i8* %14 to i64*, !enzyme_caststack !4
  %16 = bitcast i64* %15 to i8*
  %17 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %18 = bitcast i8* %17 to i64*, !enzyme_caststack !4
  %19 = bitcast i64* %18 to i8*
  %20 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %21 = bitcast i8* %20 to i64*, !enzyme_caststack !4
  %22 = bitcast i64* %21 to i8*
  %23 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %24 = bitcast i8* %23 to i64*, !enzyme_caststack !4
  %25 = bitcast i64* %24 to i8*
  %26 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %27 = bitcast i8* %26 to i64*, !enzyme_caststack !4
  %28 = bitcast i64* %27 to i8*
  %29 = call {}*** @julia.get_pgcstack() #13
  %30 = addrspacecast {} addrspace(10)* %1 to {} addrspace(11)*, !dbg !480
  %31 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %30) #14, !dbg !480
  %32 = bitcast {}* %31 to i8**, !dbg !480
  %33 = load i8*, i8** %32, align 8, !dbg !480, !tbaa !42, !invariant.load !4, !nonnull !4
  %34 = ptrtoint i8* %33 to i64, !dbg !480
  %35 = addrspacecast {} addrspace(10)* %2 to {} addrspace(11)*, !dbg !483
  %36 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %35) #14, !dbg !483
  %37 = bitcast {}* %36 to i8**, !dbg !483
  %38 = load i8*, i8** %37, align 8, !dbg !483, !tbaa !42, !invariant.load !4, !nonnull !4
  %39 = ptrtoint i8* %38 to i64, !dbg !483
  %40 = addrspacecast {} addrspace(10)* %0 to {} addrspace(11)*, !dbg !486
  %41 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %40) #14, !dbg !486
  %42 = bitcast {}* %41 to i8**, !dbg !486
  %43 = load i8*, i8** %42, align 8, !dbg !486, !tbaa !42, !invariant.load !4, !nonnull !4
  %44 = ptrtoint i8* %43 to i64, !dbg !486
  %45 = bitcast {} addrspace(10)* %1 to {} addrspace(10)* addrspace(10)*, !dbg !489
  %46 = addrspacecast {} addrspace(10)* addrspace(10)* %45 to {} addrspace(10)* addrspace(11)*, !dbg !489
  %47 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %46, i64 3, !dbg !489
  %48 = bitcast {} addrspace(10)* addrspace(11)* %47 to i64 addrspace(11)*, !dbg !489
  %49 = load i64, i64 addrspace(11)* %48, align 8, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  %50 = bitcast {} addrspace(10)* %2 to {} addrspace(10)* addrspace(10)*, !dbg !489
  %51 = addrspacecast {} addrspace(10)* addrspace(10)* %50 to {} addrspace(10)* addrspace(11)*, !dbg !489
  %52 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %51, i64 4, !dbg !489
  %53 = bitcast {} addrspace(10)* addrspace(11)* %52 to i64 addrspace(11)*, !dbg !489
  %54 = load i64, i64 addrspace(11)* %53, align 16, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  %55 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %46, i64 4, !dbg !489
  %56 = bitcast {} addrspace(10)* addrspace(11)* %55 to i64 addrspace(11)*, !dbg !489
  %57 = load i64, i64 addrspace(11)* %56, align 16, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  call void @llvm.lifetime.start.p0i8(i64 noundef 1, i8* noundef nonnull %4) #13
  store i8 78, i8* %4, align 1, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 1, i8* noundef nonnull %3) #13
  store i8 78, i8* %3, align 1, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %7) #13
  store i64 %49, i64* %6, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %10) #13
  store i64 %54, i64* %9, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %13) #13
  store i64 %57, i64* %12, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %16) #13
  %58 = bitcast i64* %15 to double*, !dbg !491
  store double 1.000000e+00, double* %58, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %19) #13
  store i64 %49, i64* %18, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %22) #13
  store i64 %57, i64* %21, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %25) #13
  %59 = bitcast i64* %24 to double*, !dbg !491
  store double 0.000000e+00, double* %59, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %28) #13
  store i64 %49, i64* %27, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @.text(i8* noundef nonnull %4, i8* noundef nonnull %3, i8* noundef nonnull %7, i8* noundef nonnull %10, i8* noundef nonnull %13, i8* noundef nonnull %16, i64 %34, i8* noundef nonnull %19, i64 %39, i8* noundef nonnull %22, i8* noundef nonnull %25, i64 %44, i8* noundef nonnull %28) #13 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !494
  ret void, !dbg !498

No augmented forward pass found for .text
 at context:   call void @.text(i8* noundef nonnull %4, i8* noundef nonnull %3, i8* noundef nonnull %7, i8* noundef nonnull %10, i8* noundef nonnull %13, i8* noundef nonnull %16, i64 %34, i8* noundef nonnull %19, i64 %39, i8* noundef nonnull %22, i8* noundef nonnull %25, i64 %44, i8* noundef nonnull %28) #13 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !45

 [1] gemm!
   @ C:\Users\jerem\.julia\packages\NNlib\Fg3DQ\src\gemm.jl:48
 [2] my_gemm!
   @ c:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:14

 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\tZYHp\src\compiler.jl:1251
jakubMitura14 commented 10 months ago

No, I have not tried it yet; I needed to work on something different for some time.

wsmoses commented 10 months ago

@jeremiedb you should be able to directly differentiate NNlib.conv (it has an EnzymeRule in NNlib)

jeremiedb commented 10 months ago

Oh I missed the newly added EnzymeRules in NNlib, thanks for pointing that out!

However, I just hit a new issue with Enzyme v0.11.10, NNlib v0.9.8, Julia 1.10-rc1, Windows.

using Enzyme
using NNlib

loss(w, x) = sum(conv(x, w))
w = randn(Float32, 3, 3, 5, 7);
dw = zero(w);
x = randn(Float32, (3, 3, 5, 8));
loss(w, x);
grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));

The call to loss works fine, but then the autodiff results in the following:

julia> grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x))
ERROR: AssertionError: legal
  [1] array_shadow_handler(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, numArgs::UInt64, Args::Ptr{Ptr{…}}, gutils::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:982
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo,
 uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\api.jl:141
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:7726
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9278
  [5] codegen
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:8886 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9830
  [7] cached_compilation
    @ C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9864 [inlined]
  [8] (::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{…}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9921
  [9] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{…}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler C:\Users\jerem\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:47
 [10] #s325#473
    @ C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9882 [inlined]
    @ Enzyme.Compiler .\none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core .\boot.jl:600
 [13] autodiff
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:207 [inlined]
 [14] autodiff
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:236 [inlined]
 [15] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Duplicated{Array{Float32, 4}}, ::Const{Array{Float32, 4}})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:222
 [16] top-level scope
    @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

I should be able to test on a Linux machine tomorrow, in case it's a Windows-specific issue.

jeremiedb commented 10 months ago

@wsmoses I've isolated an MWE of the above "legal" error, which arises on both Windows and Ubuntu. For illustration, the following 2-level loop works fine:

using Enzyme

"""
    my_conv_1(x, w)

Return a new array shaped like `x` in which every first-dimension slice
`x[:, wi, b]` has been multiplied elementwise by `w`.

This is the two-level-loop case from the report, the one that differentiates
without error. `x` is indexed with three subscripts and `w` is broadcast against
the first dimension of `x`.
"""
function my_conv_1(x, w)
    y = zero(x)
    for b in axes(y, 3)
        for wi in axes(y, 2)
            # Kept exactly as posted (including the copying slice on the right-hand
            # side): this snippet is a reproduction, not tuned code.
            y[:, wi, b] .= w .* x[:, wi, b]
        end
    end
    return y
end
x = rand(Float32, 3, 5, 8);
w = rand(Float32, 3);
y = my_conv_1(x, w);
loss1(x, w) = sum(my_conv_1(x, w))
dw = zero(w);
loss1(x, w)
grads = Enzyme.autodiff(Reverse, loss1, Const(x), Duplicated(w, dw));

However, when adding another dimension to the data, it errors:

"""
    my_conv_2(x, w)

Return a new array shaped like `x` in which every first-dimension slice
`x[:, wi, hi, b]` has been multiplied elementwise by `w`.

This is the three-level-loop variant from the report, the one that triggers the
`AssertionError: legal` failure under `Enzyme.autodiff`. `x` is indexed with four
subscripts and `w` is broadcast against the first dimension of `x`.
"""
function my_conv_2(x, w)
    y = zero(x)
    for b in axes(y, 4)
        for hi in axes(y, 3)
            for wi in axes(y, 2)
                # Kept exactly as posted so the snippet still reproduces the
                # reported failure.
                y[:, wi, hi, b] .= w .* x[:, wi, hi, b]
            end
        end
    end
    return y
end
x = rand(Float32, 3, 5, 5, 8);
w = rand(Float32, 3);
y = my_conv_2(x, w);
loss2(x, w) = sum(my_conv_2(x, w))
dw = zero(w);
loss2(x, w)
grads = Enzyme.autodiff(Reverse, loss2, Const(x), Duplicated(w, dw));


ERROR: AssertionError: legal
  [1] array_shadow_handler(B::Ptr{…}, OrigCI::Ptr{…}, numArgs::UInt64, Args::Ptr{…}, gutils::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:982
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/rbuCz/src/api.jl:141
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:7726
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9278
  [5] codegen
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:8886 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
  [7] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
  [8] (::Enzyme.Compiler.var"#474#475"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
  [9] JuliaContext(f::Enzyme.Compiler.var"#474#475"{…})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [10] #s325#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:600
 [13] autodiff
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207 [inlined]
 [14] autodiff
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:236 [inlined]
 [15] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss2), ::Const{Array{Float32, 4}}, ::Duplicated{Vector{Float32}})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [16] top-level scope
    @ ~/github/ADTests.jl/experiments/conv-v3.jl:1
Some type information was truncated. Use `show(err)` to see complete types.

Curiously, it also errors if the above only performs a single loop:

"""
    my_conv_3(x, w)

Return a zero array shaped like `x` whose first element has accumulated
`w[1] * x[1]` once per index along the third dimension of `x`.

This is the single-loop variant from the report, which also errors under
`Enzyme.autodiff` when `x` is four-dimensional.
"""
function my_conv_3(x, w)
    y = zero(x)
    for hi in axes(y, 3)
        # The loop variable is intentionally unused; the body is kept as posted
        # so the snippet still reproduces the reported failure.
        y[1] += w[1] * x[1]
    end
    return y
end
x = rand(Float32, 3, 5, 5, 8);
w = rand(Float32, 3);
y = my_conv_3(x, w);
loss3(x, w) = sum(my_conv_3(x, w))
dw = zero(w);
loss3(x, w)
grads = Enzyme.autodiff(Reverse, loss3, Const(x), Duplicated(w, dw));

From the above, it looks like there's something happening in the handling of Arrays of 4 or more dimensions. Did I miss something obvious?

Status `~/github/ADTests.jl/Project.toml`
  [052768ef] CUDA v5.1.1
  [d360d2e6] ChainRulesCore v1.18.0
  [7da242da] Enzyme v0.11.10
  [587475ba] Flux v0.14.6
  [bdcacae8] LoopVectorization v0.12.166
  [872c559c] NNlib v0.9.8
  [3bd65402] Optimisers v0.3.1
  [37e2e3b7] ReverseDiff v1.15.1
  [bc48ee85] Tullio v0.3.7
  [cd998857] Yota v0.8.5
  [e88e6eb3] Zygote v0.6.67
wsmoses commented 10 months ago

@jeremiedb The issue had nothing to do with NNlib; it was caused by Julia's special-case handling for arrays of 1, 2, or 3 dimensions. It should be fixed in https://github.com/EnzymeAD/Enzyme.jl/pull/1157