SciML / Optimization.jl

Mathematical Optimization in Julia. Local, global, gradient-based and derivative-free. Linear, Quadratic, Convex, Mixed-Integer, and Nonlinear Optimization in one simple, fast, and differentiable interface.
https://docs.sciml.ai/Optimization/stable/
MIT License
704 stars, 77 forks

A bug in Optimization.jl's Enzyme extension #621

Closed. enigne closed this issue 9 months ago.

enigne commented 9 months ago

Hi, I ran into a bug when running the following code:

using Enzyme
using LinearAlgebra
using Optimization, OptimizationOptimJL

n = 4

function mylsovle(b::Vector{Float64}, A::Matrix{Float64})::Float64
   x = A\b
   norm(x)
end
function enzymerule(f, b::Vector{Float64}, A::Matrix{Float64})
   dA = zero(A)
   db = zero(b)

   Enzyme.autodiff(Reverse, f, Duplicated(b, db), Duplicated(A, dA))
   return db
end

b0 = rand(n)
A = rand(n, n)
optprob = OptimizationFunction(mylsovle, Optimization.AutoEnzyme(), grad=enzymerule)
prob = Optimization.OptimizationProblem(optprob, b0, A, lb = -ones(n), ub = ones(n))
sol = solve(prob, Optim.LBFGS())

The error message is as follows:

ERROR: MethodError: no method matching Vector{Float64}(::Vector{Float64}, ::linearSys)

Closest candidates are:
  Array{T, N}(::Nothing, ::Any...) where {T, N}
   @ Base baseext.jl:42
  Array{T, N}(::Missing, ::Any...) where {T, N}
   @ Base baseext.jl:43
  Array{T, N}(::AbstractArray{S, N}) where {T, N, S}
   @ Base array.jl:671

Stacktrace:
  [1] macro expansion
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/utils.jl:0 [inlined]
  [2] codegen_world_age(ft::Type{Vector{Float64}}, tt::Type{Tuple{Vector{Float64}, linearSys}})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/utils.jl:141
  [3] autodiff
    @ Main ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:189 [inlined]
  [4] autodiff
    @ Main ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:236 [inlined]
  [5] autodiff
    @ Main ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222 [inlined]
  [6] gf(f::Vector{Float64}, b::Vector{Float64}, Q::linearSys)
    @ Main ./REPL[14]:5
  [7] (::Optimization.var"#20#27"{OptimizationFunction{…}, Optimization.ReInitCache{…}})(::Vector{Float64}, ::Vector{Float64})
    @ Optimization ~/.julia/packages/Optimization/fPVIW/src/function.jl:77
  [8] (::OptimizationOptimJL.var"#19#23"{OptimizationCache{…}, OptimizationOptimJL.var"#18#22"{…}})(G::Vector{Float64}, θ::Vector{Float64})
    @ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/o91yE/src/OptimizationOptimJL.jl:274
  [9] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
 [10] value_gradient!!(bw::Optim.BarrierWrapper{OnceDifferentiable{Float64, Vector{…}, Vector{…}}, Optim.BoxBarrier{Vector{…}, Vector{…}}, Float64, Float64, Vector{Float64}}, x::Vector{Float64})
    @ Optim ~/.julia/packages/Optim/V8ZEC/src/multivariate/solvers/constrained/fminbox.jl:81
 [11] initial_state(method::LBFGS{…}, options::Optim.Options{…}, d::Optim.BarrierWrapper{…}, initial_x::Vector{…})
    @ Optim ~/.julia/packages/Optim/V8ZEC/src/multivariate/solvers/first_order/l_bfgs.jl:164
 [12] optimize(df::OnceDifferentiable{…}, l::Vector{…}, u::Vector{…}, initial_x::Vector{…}, F::Fminbox{…}, options::Optim.Options{…})
    @ Optim ~/.julia/packages/Optim/V8ZEC/src/multivariate/solvers/constrained/fminbox.jl:322
 [13] __solve(cache::OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Vector{…}, Vector{…}, Nothing, Nothing, Nothing, Fminbox{…}, Base.Iterators.Cycle{…}, Bool, OptimizationOptimJL.var"#3#5"})
    @ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/o91yE/src/OptimizationOptimJL.jl:298
 [14] solve!
    @ SciMLBase ~/.julia/packages/SciMLBase/VS2ST/src/solve.jl:162 [inlined]
 [15] #solve#619
    @ SciMLBase ~/.julia/packages/SciMLBase/VS2ST/src/solve.jl:83 [inlined]
 [16] solve(::OptimizationProblem{…}, ::LBFGS{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/VS2ST/src/solve.jl:80
 [17] top-level scope
    @ REPL[23]:1

Could you please help me fix this? Thank you. @wsmoses

Vaibhavdixit02 commented 9 months ago

There are multiple things going on here. Your function enzymerule does not have the right arguments to be a valid gradient function in Optimization.jl:

function enzymerule(G, b::Vector{Float64}, p, f, A::Matrix{Float64})
   dA = zero(A)

   Enzyme.autodiff(Reverse, f, Duplicated(b, G), Duplicated(A, dA))
end

is the corrected one, though it gives me another error that I am not sure about:

julia> sol = solve(prob, Optim.LBFGS())
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
 caching call:   %37 = call fastcc i64 @julia_steprange_last_2801(i64 signext %36, i64 noundef signext 4, i64 signext %34) #91, !dbg !166
 caching call:   %40 = call fastcc i64 @julia_steprange_last_2801(i64 signext %39, i64 noundef signext 4, i64 signext %37) #91, !dbg !166
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
 caching call:   %37 = call fastcc i64 @julia_steprange_last_5635(i64 signext %36, i64 noundef signext 4, i64 signext %34) #91, !dbg !166
 caching call:   %40 = call fastcc i64 @julia_steprange_last_5635(i64 signext %39, i64 noundef signext 4, i64 signext %37) #91, !dbg !166
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
 caching call:   %37 = call fastcc i64 @julia_steprange_last_5957(i64 signext %36, i64 noundef signext 4, i64 signext %34) #91, !dbg !166
 caching call:   %40 = call fastcc i64 @julia_steprange_last_5957(i64 signext %39, i64 noundef signext 4, i64 signext %37) #91, !dbg !166
warning: didn't implement memmove, using memcpy as fallback which can result in errors
ERROR: Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__getrf__1_6208({ {} addrspace(10)*, {} addrspace(10)*, i64 }* noalias nocapture nofree noundef nonnull writeonly sret({ {} addrspace(10)*, {} addrspace(10)*, i64 }) align 8 dereferenceable(24) %0, [2 x {} addrspace(10)*]* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(16) "enzyme_inactive" "enzymejl_returnRoots" %1, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %2) unnamed_addr #98 !dbg !7177 {
top:
  %3 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !141
  %4 = bitcast i8* %3 to i64*, !enzyme_caststack !121
  %5 = bitcast i64* %4 to i8*
  %6 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !141
  %7 = bitcast i8* %6 to i64*, !enzyme_caststack !121
  %8 = bitcast i64* %7 to i8*
  %9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !141
  %10 = bitcast i8* %9 to i64*, !enzyme_caststack !121
  %11 = bitcast i64* %10 to i8*
  %12 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !141
  %13 = bitcast i8* %12 to i64*, !enzyme_caststack !121
  %14 = bitcast i64* %13 to i8*
  %15 = call {}*** @julia.get_pgcstack() #117
  %16 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !7178
  %17 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %16 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !7178
  %18 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %17, i64 0, i32 1, !dbg !7178
  %19 = load i64, i64 addrspace(11)* %18, align 8, !dbg !7178, !tbaa !128, !range !144, !invariant.load !121, !alias.scope !132, !noalias !135
  %.not16 = icmp eq i64 %19, 0, !dbg !7183
  br i1 %.not16, label %L51, label %L18, !dbg !7179

L18:                                              ; preds = %top
  %20 = bitcast {} addrspace(10)* %2 to double addrspace(13)* addrspace(10)*, !dbg !7185
  %21 = addrspacecast double addrspace(13)* addrspace(10)* %20 to double addrspace(13)* addrspace(11)*, !dbg !7185
  %22 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %21, align 8, !dbg !7185, !tbaa !128, !invariant.load !121, !alias.scope !7186, !noalias !135, !nonnull !121
  %value_phi341 = load double, double addrspace(13)* %22, align 8, !dbg !7189, !tbaa !207, !alias.scope !163, !noalias !383
  %23 = fsub double %value_phi341, %value_phi341, !dbg !7190
  %24 = fcmp oeq double %23, 0.000000e+00, !dbg !7193
  br i1 %24, label %L31.lr.ph, label %L28, !dbg !7192

L31.lr.ph:                                        ; preds = %L18
  %25 = add nuw nsw i64 %19, 1, !dbg !7192
  br label %L31, !dbg !7192

L28.loopexit:                                     ; preds = %L43
  br label %L28, !dbg !7195

L28:                                              ; preds = %L28.loopexit, %L18
  %26 = call fastcc [1 x {} addrspace(10)*] @julia_ArgumentError_5963({} addrspace(10)* noalias nofree noundef nonnull readnone align 16 addrspacecast ({}* inttoptr (i64 4865111120 to {}*) to {} addrspace(10)*)) #118, !dbg !7195
  %current_task517 = getelementptr inbounds {}**, {}*** %15, i64 -13, !dbg !7195
  %current_task5 = bitcast {}*** %current_task517 to {}**, !dbg !7195
  %27 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task5, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4896143584 to {}*) to {} addrspace(10)*)) #119, !dbg !7195
  %28 = extractvalue [1 x {} addrspace(10)*] %26, 0, !dbg !7195
  %29 = bitcast {} addrspace(10)* %27 to {} addrspace(10)* addrspace(10)*, !dbg !7195
  store {} addrspace(10)* %28, {} addrspace(10)* addrspace(10)* %29, align 8, !dbg !7195, !tbaa !159, !alias.scope !163, !noalias !7196
  %30 = addrspacecast {} addrspace(10)* %27 to {} addrspace(12)*, !dbg !7195
  call void @ijl_throw({} addrspace(12)* %30) #120, !dbg !7195
  unreachable, !dbg !7195

L31:                                              ; preds = %L43, %L31.lr.ph
  %iv = phi i64 [ %iv.next, %L43 ], [ 0, %L31.lr.ph ]
  %31 = add i64 %iv, 2, !dbg !7197
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !7197
  %exitcond.not = icmp eq i64 %31, %25, !dbg !7197
  br i1 %exitcond.not, label %L51.loopexit, label %L43, !dbg !7199

L43:                                              ; preds = %L31
  %32 = add nsw i64 %31, -1, !dbg !7201
  %33 = getelementptr inbounds double, double addrspace(13)* %22, i64 %32, !dbg !7203
  %34 = add nuw i64 %31, 1, !dbg !7204
  %value_phi3 = load double, double addrspace(13)* %33, align 8, !dbg !7189, !tbaa !207, !alias.scope !163, !noalias !383
  %35 = fsub double %value_phi3, %value_phi3, !dbg !7190
  %36 = fcmp oeq double %35, 0.000000e+00, !dbg !7193
  br i1 %36, label %L31, label %L28.loopexit, !dbg !7192

L51.loopexit:                                     ; preds = %L31
  br label %L51, !dbg !7205

L51:                                              ; preds = %L51.loopexit, %top
  %37 = bitcast {} addrspace(10)* %2 to {} addrspace(10)* addrspace(10)*, !dbg !7205
  %38 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(10)* %37, i64 3, !dbg !7205
  %39 = bitcast {} addrspace(10)* addrspace(10)* %38 to i64 addrspace(10)*, !dbg !7205
  %40 = addrspacecast i64 addrspace(10)* %39 to i64 addrspace(11)*, !dbg !7205
  %41 = load i64, i64 addrspace(11)* %40, align 8, !dbg !7205, !tbaa !128, !range !144, !invariant.load !121, !alias.scope !132, !noalias !135
  %42 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(10)* %37, i64 4, !dbg !7205
  %43 = bitcast {} addrspace(10)* addrspace(10)* %42 to i64 addrspace(10)*, !dbg !7205
  %44 = addrspacecast i64 addrspace(10)* %43 to i64 addrspace(11)*, !dbg !7205
  %45 = load i64, i64 addrspace(11)* %44, align 16, !dbg !7205, !tbaa !128, !range !144, !invariant.load !121, !alias.scope !132, !noalias !135
  %.not19 = icmp eq i64 %41, 0, !dbg !7207
  %46 = select i1 %.not19, i64 1, i64 %41, !dbg !7210
  %.not20 = icmp ult i64 %45, %41, !dbg !7211
  %47 = select i1 %.not20, i64 %45, i64 %41, !dbg !7214
  %48 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4811830208 to {}*) to {} addrspace(10)*), i64 %47) #121, !dbg !7215
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %5) #117
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %8) #117
  store i64 %41, i64* %7, align 16, !dbg !7219, !tbaa !373, !alias.scope !163, !noalias !7196
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %11) #117
  store i64 %45, i64* %10, align 16, !dbg !7219, !tbaa !373, !alias.scope !163, !noalias !7196
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %14) #117
  store i64 %46, i64* %13, align 16, !dbg !7219, !tbaa !373, !alias.scope !163, !noalias !7196
  %49 = addrspacecast {} addrspace(10)* %2 to {} addrspace(11)*, !dbg !7223
  %50 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* noundef %49) #122, !dbg !7223
  %51 = bitcast {}* %50 to i8**, !dbg !7223
  %52 = load i8*, i8** %51, align 8, !dbg !7223, !tbaa !128, !invariant.load !121, !alias.scope !132, !noalias !135, !nonnull !121
  %53 = ptrtoint i8* %52 to i64, !dbg !7223
  %54 = addrspacecast {} addrspace(10)* %48 to {} addrspace(11)*, !dbg !7223
  %55 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %54) #122, !dbg !7223
  %56 = bitcast {}* %55 to i8**, !dbg !7223
  %57 = load i8*, i8** %56, align 8, !dbg !7223, !tbaa !204, !alias.scope !150, !noalias !151, !nonnull !121
  %58 = ptrtoint i8* %57 to i64, !dbg !7223
  %59 = ptrtoint i64* %4 to i64, !dbg !7224
  call void @dgetrf_64_(i8* noundef nonnull %8, i8* noundef nonnull %11, i64 %53, i8* noundef nonnull %14, i64 %58, i64 noundef %59) #117 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* %48, {} addrspace(10)* null, {} addrspace(10)* %2, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !7222
  %60 = load i64, i64* %4, align 16, !dbg !7226, !tbaa !373, !alias.scope !163, !noalias !383
  %61 = icmp sgt i64 %60, -1, !dbg !7229
  br i1 %61, label %L84, label %L78, !dbg !7230

L78:                                              ; preds = %L51
  %current_task921 = getelementptr inbounds {}**, {}*** %15, i64 -13, !dbg !7231
  %current_task9 = bitcast {}*** %current_task921 to {}**, !dbg !7231
  %62 = sub i64 0, %60, !dbg !7234
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %5) #117
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %8) #117
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %11) #117
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %14) #117
  %63 = call noalias nonnull {} addrspace(10)* @ijl_box_int64(i64 signext %62) #121, !dbg !7235
  %64 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* noundef nonnull @ijl_invoke, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4848051248 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4815756384 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4865151776 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %63, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4865151744 to {}*) to {} addrspace(10)*)) #123, !dbg !7235
  %65 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task9, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4896143584 to {}*) to {} addrspace(10)*)) #119, !dbg !7235
  %66 = bitcast {} addrspace(10)* %65 to {} addrspace(10)* addrspace(10)*, !dbg !7235
  store {} addrspace(10)* %64, {} addrspace(10)* addrspace(10)* %66, align 8, !dbg !7235, !tbaa !159, !alias.scope !163, !noalias !7196
  %67 = addrspacecast {} addrspace(10)* %65 to {} addrspace(12)*, !dbg !7235
  call void @ijl_throw({} addrspace(12)* %67) #120, !dbg !7235
  unreachable, !dbg !7235

L84:                                              ; preds = %L51
  %68 = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !7236
  store {} addrspace(10)* %2, {} addrspace(10)** %68, align 8, !dbg !7236, !noalias !7237
  %69 = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*]* %1, i64 0, i64 1, !dbg !7236
  store {} addrspace(10)* %48, {} addrspace(10)** %69, align 8, !dbg !7236, !noalias !7237
  %.repack = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 }* %0, i64 0, i32 0, !dbg !7236
  store {} addrspace(10)* %2, {} addrspace(10)** %.repack, align 8, !dbg !7236, !noalias !7237
  %.repack26 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 }* %0, i64 0, i32 1, !dbg !7236
  store {} addrspace(10)* %48, {} addrspace(10)** %.repack26, align 8, !dbg !7236, !noalias !7237
  %.repack28 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 }* %0, i64 0, i32 2, !dbg !7236
  store i64 %60, i64* %.repack28, align 8, !dbg !7236, !noalias !7237
  ret void, !dbg !7236
}

Illegal replace ficticious phi for:   %_replacementA35 = phi {} addrspace(10)* , !dbg !149 of   %48 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4811830208 to {}*) to {} addrspace(10)*), i64 %47) #121, !dbg !198
; Function Attrs: mustprogress willreturn
define internal fastcc void @diffejulia__getrf__1_6208({ {} addrspace(10)*, {} addrspace(10)*, i64 }* noalias nocapture nofree writeonly sret({ {} addrspace(10)*, {} addrspace(10)*, i64 }) align 8 dereferenceable(24) "enzyme_sret" %0, { {} addrspace(10)*, {} addrspace(10)*, i64 }* nocapture nofree align 8 "enzyme_sret" %"'", [2 x {} addrspace(10)*]* noalias nocapture nofree writeonly align 8 dereferenceable(16) "enzyme_inactive" "enzymejl_returnRoots" %1, {} addrspace(10)* align 16 dereferenceable(40) %2, {} addrspace(10)* align 16 %"'1", { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg) unnamed_addr #98 !dbg !12088 {
top:
  %3 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 3
  %4 = bitcast i8* %3 to i64*, !enzyme_caststack !121
  %5 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 2
  %6 = bitcast i8* %5 to i64*, !enzyme_caststack !121
  %7 = bitcast i64* %6 to i8*
  %8 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 1
  %9 = bitcast i8* %8 to i64*, !enzyme_caststack !121
  %10 = bitcast i64* %9 to i8*
  %11 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 0
  %12 = bitcast i8* %11 to i64*, !enzyme_caststack !121
  %13 = bitcast i64* %12 to i8*
  %14 = call {}*** @julia.get_pgcstack() #117
  %_replacementA6 = phi { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* , !dbg !12089
  %_replacementA5 = phi { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* , !dbg !12089
  %15 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 7, !dbg !12094
  %.not16 = icmp eq i64 %15, 0, !dbg !12094
  br i1 %.not16, label %L51, label %L18, !dbg !12090

L18:                                              ; preds = %top
  br i1 true, label %L31.lr.ph, label %L28, !dbg !12096

L31.lr.ph:                                        ; preds = %L18
  %16 = add nuw nsw i64 %15, 1, !dbg !12096
  %17 = add nsw i64 %15, -1, !dbg !12096
  br label %L31, !dbg !12096

L28.loopexit:                                     ; preds = %L43
  unreachable

L28:                                              ; preds = %L18
  %current_task5_replacementA = phi {}** , !dbg !12097
  %_replacementA16 = phi {} addrspace(10)* , !dbg !12097
  unreachable

L31:                                              ; preds = %L43, %L31.lr.ph
  %iv = phi i64 [ %iv.next, %L43 ], [ 0, %L31.lr.ph ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !12098
  %18 = add i64 %iv, 2, !dbg !12098
  %exitcond.not = icmp eq i64 %18, %16, !dbg !12098
  br i1 %exitcond.not, label %L51.loopexit, label %L43, !dbg !12100

L43:                                              ; preds = %L31
  br i1 true, label %L31, label %L28.loopexit, !dbg !12096

L51.loopexit:                                     ; preds = %L31
  br label %L51, !dbg !12102

L51:                                              ; preds = %L51.loopexit, %top
  %_replacementA45 = phi {} addrspace(10)* addrspace(10)* , !dbg !12102
  %_replacementA42 = phi i64 addrspace(11)* , !dbg !12102
  %_replacementA40 = phi {} addrspace(10)* addrspace(10)* , !dbg !12102
  %_replacementA38 = phi i64 addrspace(11)* , !dbg !12102
  %.not19_replacementA = phi i1 , !dbg !12104
  %19 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 8, !dbg !12107
  %20 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 6, !dbg !12107
  %_replacementA35 = phi {} addrspace(10)* , !dbg !12107
  %"'ipc32" = addrspacecast {} addrspace(10)* %"'1" to {} addrspace(11)*, !dbg !12112
  %21 = call {}* @julia.pointer_from_objref({} addrspace(11)* %"'ipc32"), !dbg !12112
  %"'il_phi2" = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 5, !dbg !12112
  %"'ipc" = ptrtoint i8* %"'il_phi2" to i64, !dbg !12112
  %22 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 9, !dbg !12112
  %_replacementA27 = phi i8** , !dbg !12112
  %23 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 10, !dbg !12114
  %24 = ptrtoint i64* %4 to i64, !dbg !12114
  %tapeArg25 = extractvalue { i8*, i8*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, i64, i64, i64, i64 } %tapeArg, 4, !dbg !12113
  br i1 true, label %L84, label %L78, !dbg !12116

L78:                                              ; preds = %L51
  %current_task9_replacementA = phi {}** , !dbg !12118
  %_replacementA47 = phi {} addrspace(10)* addrspace(10)* , !dbg !12121
  unreachable

L84:                                              ; preds = %L51
  %.repack_replacementA = phi {} addrspace(10)** , !dbg !12122
  br label %invertL84, !dbg !12122

allocsForInversion:                               ; No predecessors!
  %"iv'ac" = alloca i64, align 8

inverttop:                                        ; preds = %invertL51, %invertL18
  call void @free(i8* %11)
  call void @free(i8* %8)
  call void @free(i8* %5)
  call void @free(i8* %3)
  ret void

invertL18:                                        ; preds = %invertL31.lr.ph
  br label %inverttop

invertL31.lr.ph:                                  ; preds = %invertL31
  br label %invertL18

invertL28.loopexit:                               ; No predecessors!

invertL28:                                        ; No predecessors!

invertL31:                                        ; preds = %mergeinvertL31_L51.loopexit, %invertL43
  %25 = load i64, i64* %"iv'ac", align 8
  %26 = icmp eq i64 %25, 0
  %27 = xor i1 %26, true
  br i1 %26, label %invertL31.lr.ph, label %incinvertL31

incinvertL31:                                     ; preds = %invertL31
  %28 = load i64, i64* %"iv'ac", align 8
  %29 = add nsw i64 %28, -1
  store i64 %29, i64* %"iv'ac", align 8
  br label %invertL43

invertL43:                                        ; preds = %incinvertL31
  br label %invertL31

invertL51.loopexit:                               ; preds = %invertL51
  %_unwrap = add nsw i64 %15, -1
  br label %mergeinvertL31_L51.loopexit

mergeinvertL31_L51.loopexit:                      ; preds = %invertL51.loopexit
  store i64 %_unwrap, i64* %"iv'ac", align 8
  br label %invertL31

invertL51:                                        ; preds = %invertL84
  call void inttoptr (i64 7235321712 to void (i8*)*)(i8* getelementptr inbounds ([12211 x i8], [12211 x i8]* @15, i32 0, i32 0)) #118, !dbg !12113
  call void @diffedgetrf_64_(i8* %7, i8* %10, i64 %22, i64 %"'ipc", i8* %13, i64 %23, i64 %24, {} addrspace(10)* %tapeArg25) [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* %_replacementA35, {} addrspace(10)* %20, {} addrspace(10)* null, {} addrspace(10)* %2, {} addrspace(10)* %"'1", {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !12113
  br i1 %.not16, label %inverttop, label %invertL51.loopexit

invertL78:                                        ; No predecessors!

invertL84:                                        ; preds = %L84
  br label %invertL51
}

LLVM.CallInst(%48 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4811830208 to {}*) to {} addrspace(10)*), i64 %47) #121, !dbg !198)
LLVM.PHIInst(%_replacementA35 = phi {} addrspace(10)* , !dbg !149)

Stacktrace:
 [1] Array
   @ ./boot.jl:477
 [2] Array
   @ ./boot.jl:486
 [3] similar
   @ ./array.jl:374
 [4] similar
   @ ./abstractarray.jl:838
 [5] #getrf!#1
   @ ~/.julia/juliaup/julia-1.9.2+0.x64.apple.darwin14/share/julia/stdlib/v1.9/LinearAlgebra/src/lapack.jl:563

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:5891
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/5wFGb/src/api.jl:141
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type, loweredArgs::Set{Int64}, boxedArgs::Set{Int64})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:7715
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:9372
  [5] codegen
    @ ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:8980 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:9924
  [7] _thunk
    @ ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:9924 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:9958 [inlined]
  [9] (::Enzyme.Compiler.var"#473#474"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:10015
 [10] JuliaContext(f::Enzyme.Compiler.var"#473#474"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [11] #s312#472
    @ ~/.julia/packages/Enzyme/5wFGb/src/compiler.jl:9976 [inlined]
 [12] var"#s312#472"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [14] autodiff
    @ ~/.julia/packages/Enzyme/5wFGb/src/Enzyme.jl:207 [inlined]
 [15] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(mylsovle)}, ::Duplicated{Vector{Float64}}, ::Duplicated{Matrix{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/5wFGb/src/Enzyme.jl:236
 [16] autodiff
    @ ~/.julia/packages/Enzyme/5wFGb/src/Enzyme.jl:222 [inlined]
 [17] enzymerule(G::Vector{Float64}, b::Vector{Float64}, p::Matrix{Float64}, f::Function, A::Matrix{Float64})
    @ Main ./REPL[16]:4
 [18] (::var"#5#6")(G::Vector{Float64}, x::Vector{Float64}, p::Matrix{Float64})
    @ Main ./REPL[20]:1
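
For reference on the signature point above: the stack trace shows Optimization.jl calling the user gradient as (G, x, p), i.e. an in-place function that fills G with the gradient with respect to the optimization variable, with the parameters p (here A) passed through. A minimal sketch of that convention, assuming a hand-derived closed-form gradient instead of Enzyme so it runs with LinearAlgebra alone; lsolve_grad! and e_i are hypothetical names. One difference worth noting: Enzyme's reverse mode accumulates into a Duplicated shadow, so an Enzyme-based gradient that uses Duplicated(b, G) would need G zeroed first (e.g. fill!(G, 0)), whereas this sketch overwrites G.

```julia
using LinearAlgebra

# Objective from the report: the norm of the solution of A x = b.
mylsovle(b::Vector{Float64}, A::Matrix{Float64})::Float64 = norm(A \ b)

# Hypothetical in-place gradient in the (G, u, p) form:
# u is the optimization variable (b) and p is the parameter (A).
# Uses the closed-form gradient A' \ (x / norm(x)) with x = A \ b.
function lsolve_grad!(G::Vector{Float64}, b::Vector{Float64}, A::Matrix{Float64})
    x = A \ b
    G .= A' \ (x ./ norm(x))  # overwrite G instead of accumulating into it
    return nothing
end

# Usage: compare against central finite differences.
A = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
b = [1.0, -2.0, 0.5]
G = zeros(3)
lsolve_grad!(G, b, A)

h = 1e-6
e_i(i) = [j == i ? 1.0 : 0.0 for j in 1:3]  # i-th unit vector
G_fd = [(mylsovle(b + h * e_i(i), A) - mylsovle(b - h * e_i(i), A)) / (2h) for i in 1:3]
println(maximum(abs.(G .- G_fd)))
```

Such a function could in principle be passed as grad = lsolve_grad! to OptimizationFunction; that combination is not exercised here.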
ChrisRackauckas commented 9 months ago

Enzyme doesn't support LAPACK at this time, so you'd need to use LinearSolve.jl.

But there's no reason to supply that "enzymerule": just use AutoEnzyme, which does the same thing. If your function f is supported by Enzyme, Optimization.jl will create that derivative definition and use it.
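Chris's first suggestion is to route the linear solve through LinearSolve.jl rather than calling \ directly. A minimal sketch of what that could look like for this objective, assuming LinearSolve.jl's LinearProblem/solve interface; the name mylsolve_ls is hypothetical, and whether a given Enzyme version differentiates it cleanly is not checked here.

```julia
using LinearAlgebra, LinearSolve

# Same objective as mylsovle, but the solve goes through LinearSolve.jl's
# LinearProblem/solve interface instead of calling `\` directly.
function mylsolve_ls(b::Vector{Float64}, A::Matrix{Float64})::Float64
    # Copies guard against an algorithm that factorizes or solves in place.
    prob = LinearProblem(copy(A), copy(b))
    sol = solve(prob)  # no algorithm given: LinearSolve chooses a default
    return norm(sol.u)
end

A = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
b = [1.0, -2.0, 0.5]
println(mylsolve_ls(b, A))
```

The result should agree with norm(A \ b) up to floating-point error; only the solver path differs.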

Vaibhavdixit02 commented 9 months ago

lol, I was just about to post the same thing Chris said. @enigne, the derivatives generated by Optimization.jl will probably work better for you.

wsmoses commented 9 months ago

The \ rule he's using is actually explicitly supported by Enzyme, via a PR he made himself (https://github.com/EnzymeAD/Enzyme.jl/pull/1111)!

@enigne, you may need to update to the latest Enzyme (0.11.10, which includes your PR).

Vaibhavdixit02 commented 9 months ago

julia> optprob = OptimizationFunction(mylsovle, Optimization.AutoEnzyme())
(::OptimizationFunction{true, AutoEnzyme{Nothing}, typeof(mylsovle), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)

julia> prob = Optimization.OptimizationProblem(optprob, b0, A, lb = -ones(n), ub = ones(n))
OptimizationProblem. In-place: true
u0: 4-element Vector{Float64}:
 0.4603304588047187
 0.5287713607951939
 0.9174186008736265
 0.0450911861077159

julia> sol = solve(prob, Optim.LBFGS())
ERROR: type Const has no field dval

on Enzyme 0.10.11

enigne commented 9 months ago

I got this error with Enzyme 0.11.10. Actually, AutoEnzyme() without the user-supplied gradient does not work for me either.

ChrisRackauckas commented 9 months ago

Can you isolate this to just Enzyme, without Optimization.jl?

enigne commented 9 months ago

The Enzyme part works, thanks to @wsmoses: Enzyme now supports the backslash operator, and I can run enzymerule(mylsovle, rand(n), rand(n,n)) without any error. I think the bug here is in OptimizationEnzymeExt.jl, which I don't quite understand.

ChrisRackauckas commented 9 months ago

A is a parameter, so it should be Const in the differentiation example?

ChrisRackauckas commented 9 months ago

; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__getrf__1_6208({ {} addrspace(10)*, {} addrspace(10)*, i64 }* noalias nocapture nofree noundef nonnull writeonly sret({ {} addrspace(10)*, {} addrspace(10)*, i64 }) align 8 dereferenceable(24) %0, [2 x {} addrspace(10)*]* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(16) "enzyme_inactive" "enzymejl_returnRoots" %1, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %2) unnamed_addr #98 !dbg !7177 {
top:

It's pointing to the backsolve (ldiv!) part of it, so is the dispatch not allowing Const?

enigne commented 9 months ago

"A is a parameter, so it should be Const in the differentiation example?"

This is just a simplified example. In my real case, A is a struct that contains b, but we only need the gradient of f with respect to b.
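For the simplified objective, the gradient with respect to b also has a closed form, which gives something to check an AD result against. Assuming A is invertible and x = A^{-1} b is nonzero, A enters only as a fixed matrix:

```latex
\begin{aligned}
f(b) &= \lVert x \rVert_2, \qquad x = A^{-1} b, \\
df &= \frac{x^\top dx}{\lVert x \rVert_2}, \qquad dx = A^{-1}\, db, \\
\nabla_b f &= A^{-\top} \frac{x}{\lVert x \rVert_2}.
\end{aligned}
```

In Julia this is A' \ (x / norm(x)) with x = A \ b.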

ChrisRackauckas commented 9 months ago

Yes, and it should use the Const type. I'm trying to help you build the MWE that isn't hitting the dispatch.
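Following Chris's point, a sketch of an Enzyme call that differentiates only with respect to b and marks everything else as Const. LinSys is a hypothetical stand-in for the reporter's linearSys struct (its real fields are not shown in the thread), and the sketch assumes an Enzyme version in which the \ rule accepts a constant matrix, which is what the fix linked in the next comment addresses.

```julia
using Enzyme, LinearAlgebra

# Hypothetical stand-in for the reporter's struct: it carries the data the
# objective needs, but we only differentiate with respect to b.
struct LinSys
    A::Matrix{Float64}
end

objective(b::Vector{Float64}, sys::LinSys) = norm(sys.A \ b)

function grad_wrt_b(b::Vector{Float64}, sys::LinSys)
    db = zero(b)  # shadow of b; Enzyme accumulates df/db into it
    # b is Duplicated (differentiated); sys is Const, so no shadow is
    # passed for it and no derivative with respect to A is accumulated.
    Enzyme.autodiff(Reverse, objective, Active, Duplicated(b, db), Const(sys))
    return db
end

sys = LinSys([4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0])
b = [1.0, -2.0, 0.5]
g = grad_wrt_b(b, sys)
println(g)
```

The result can be checked against the closed-form gradient A' \ (x / norm(x)) with x = A \ b.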

wsmoses commented 9 months ago

Looks like the internal \ rule didn't permit a constant A or b; try this: https://github.com/EnzymeAD/Enzyme.jl/pull/1121

Vaibhavdixit02 commented 9 months ago

Tested with Enzyme main, and this works now.