EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Const kwargs becoming Duplicated in custom rule => error #1845

Open danielwe opened 4 days ago

danielwe commented 4 days ago

MWE:

using Enzyme, QuadGK

function polyintegral(coeffs, config)
    f(x) = evalpoly(x, coeffs)
    return first(quadgk(f, -1.0, 1.0; config...))
end

coeffs = (1.0,)
config = (; rtol=1e-10)
autodiff(Reverse, polyintegral, Active, Active(coeffs), Const(config))

Output:

ERROR: LoadError: Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{rtol::Float64}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.ConfigWidth{1, true, false, (false, false, false, false)}, Const{typeof(quadgk)}, Type{Active{Tuple{Float64, Float64}}}, Active{var"#f#17"{Tuple{Float64}}}, Const{Float64}, Const{Float64}}

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6682 [inlined]
  [4] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Core.kwcall), df::Nothing, primal_1::@NamedTuple{…}, shadow_1_1::Base.RefValue{…}, primal_2::typeof(quadgk), shadow_2_1::Nothing, primal_3::var"#f#17"{…}, shadow_3_1::Base.RefValue{…}, primal_4::Float64, shadow_4_1::Nothing, primal_5::Float64, shadow_5_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:338
  [5] polyintegral
    @ ~/issues/quadgkmixed.jl:20 [inlined]
  [6] polyintegral
    @ ~/issues/quadgkmixed.jl:0 [inlined]
  [7] diffejulia_polyintegral_13503_inner_1wrap
    @ ~/issues/quadgkmixed.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [10] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
 [11] autodiff
    @ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
 [12] autodiff(::ReverseMode{…}, ::typeof(polyintegral), ::Type{…}, ::Active{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
 [13] macro expansion
    @ show.jl:1181 [inlined]
 [14] top-level scope
    @ ~/issues/quadgkmixed.jl:26
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [16] top-level scope
    @ REPL[8]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:26
Some type information was truncated. Use `show(err)` to see complete types.

It's unclear to me whether this is specific to the implementation of QuadGK's custom rule or is an issue with the custom rule machinery in general, but it looked to me like it might be the latter so I'm filing here.

wsmoses commented 3 days ago

The code is type unstable, so activity analysis couldn't prove that the derivative of config isn't needed (see runtime_generic_augfwd in the stack trace, which is our handler for type-unstable calls).

danielwe commented 3 days ago

The differentiated function is fully inferred:

julia> @code_warntype polyintegral(coeffs, config)
MethodInstance for polyintegral(::Tuple{Float64}, ::@NamedTuple{rtol::Float64})
  from polyintegral(coeffs, config) @ Main ~/issues/quadgkkwargs.jl:18
Arguments
  #self#::Core.Const(polyintegral)
  coeffs::Tuple{Float64}
  config::@NamedTuple{rtol::Float64}
Locals
  f::var"#f#1"{Tuple{Float64}}
  @_5::Tuple{Float64, Float64}
Body::Float64
1 ─ %1  = Main.:(var"#f#1")::Core.Const(var"#f#1")
│   %2  = Core.typeof(coeffs)::Core.Const(Tuple{Float64})
│   %3  = Core.apply_type(%1, %2)::Core.Const(var"#f#1"{Tuple{Float64}})
│         (f = %new(%3, coeffs))
│   %5  = Base.NamedTuple()::Core.Const(NamedTuple())
│   %6  = Base.merge(%5, config)::@NamedTuple{rtol::Float64}
│   %7  = Base.isempty(%6)::Core.Const(false)
└──       goto #3 if not %7
2 ─       Core.Const(:(@_5 = Main.quadgk(f, -1.0, 1.0)))
└──       Core.Const(:(goto %12))
3 ┄       (@_5 = Core.kwcall(%6, Main.quadgk, f, -1.0, 1.0))
│   %12 = @_5::Tuple{Float64, Float64}
│   %13 = Main.first(%12)::Float64
└──       return %13

Is the type instability introduced in the custom rule?

wsmoses commented 3 days ago

No, enzymerules don't add type instabilities

danielwe commented 3 days ago

Then I'm confused... Is the issue that the inferred types are lost on Enzyme somehow?

wsmoses commented 3 days ago

What happens if you use code_typed?

danielwe commented 3 days ago

All blue:

julia> @code_typed polyintegral(coeffs, config)
CodeInfo(
1 ─ %1 = %new(var"#f#5"{Tuple{Float64}}, coeffs)::var"#f#5"{Tuple{Float64}}
│   %2 = Core.getfield(config, :rtol)::Float64
│   %3 = QuadGK.nothing::Nothing
│   %4 = QuadGK.norm::typeof(LinearAlgebra.norm)
│   %5 = QuadGK.nothing::Nothing
│   %6 = QuadGK.nothing::Nothing
│   %7 = %new(QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}, %3, %2, 10000000, 7, %4, %5, %6)::QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}
│   %8 = invoke QuadGK.handle_infinities(%7::QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}, %1::var"#f#5"{Tuple{Float64}}, (-1.0, 1.0)::Tuple{Float64, Float64})::Tuple{Float64, Float64}
│   %9 = Base.getfield(%8, 1, true)::Float64
└──      return %9
) => Float64

Also, for good measure:

julia> Base.promote_op(polyintegral, Tuple{Float64}, @NamedTuple{rtol::Float64})
Float64

julia> Core.Compiler.return_type(polyintegral, Tuple{Tuple{Float64},@NamedTuple{rtol::Float64}})
Float64

danielwe commented 3 days ago

So it looks like there is no kwarg handling in code_typed because quadgk is inlined, and the following do block is inserted directly in the body of polyintegral: https://github.com/JuliaMath/QuadGK.jl/blob/ce727e15f76df016ee2db9819fab0b4a7c6117fe/src/api.jl#L82-L84. What does Enzyme do in such cases, where the call for which there is a custom rule has been inlined? Does it have to intercept earlier in the compiler pipeline, before inlining, where types might not be fully known?

Also, there's an invoke in front of the call to handle_infinities. I'm not sure what that means (there's no @invoke or invokelatest in the code) or whether it could matter.

Looking at this makes me wonder whether the inner function do_quadgk, which does not have kwargs and isn't inlined here, might be an even better target for the custom rule? It too supports returning the segbuf by using the ReturnSegbuf wrapper.

danielwe commented 3 days ago

Re the above musings: Adding @noinline to quadgk (line 80 following the link in the previous post) changes the @code_typed output the way you would expect but doesn't change the error in the MWE at all. And it looks like invoke is just SSA speak for a statically dispatched call to a method instance, which is obviously nothing to worry about. Clearly, I don't have much to contribute here.
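
For anyone following along, here is a minimal sketch of where invoke shows up in typed IR. The functions helper and caller are made up for illustration and have nothing to do with QuadGK or Enzyme, and the exact printed IR differs between Julia versions, so treat this as an illustration rather than a reference.

```julia
# Hypothetical toy functions, unrelated to QuadGK or Enzyme.
@noinline helper(x) = x * x + 1.0   # @noinline keeps the call visible in the IR
caller(x) = 2.0 * helper(x)

# code_typed returns one `CodeInfo => return type` pair per matching method.
ci, rettype = only(code_typed(caller, (Float64,)))

# A statically dispatched call that was not inlined appears as an
# expression with head :invoke in the optimized IR.
invokes = filter(stmt -> stmt isa Expr && stmt.head === :invoke, ci.code)

println(ci)
println("number of invoke statements: ", length(invokes))
```

Removing the @noinline annotation should make the invoke statement disappear for a function this small, since the compiler is then free to inline helper into caller.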

danielwe commented 1 day ago

I'm at a loss trying to debug this further. The MWE can be minimized to just this:

using Enzyme, QuadGK

constantintegral(a, rtol) = first(quadgk(_ -> a, -1.0, 1.0; rtol))
@show autodiff(Reverse, constantintegral, Active, Active(1.0), Const(1e-10))

It's not a general issue with float kwargs and custom rules: I confirmed that by adding a float kwarg to the example from the custom rule tutorial, and that worked fine.

I assume the call should never have gone to runtime_generic_augfwd since the function is stable and inferred, but I can't figure out where and how Enzyme decides to go that route. In any case, the fate of erroring is ultimately sealed in the following lines, but beyond that, the trail goes cold, excuse me, ccalled, as activep is set by a ccall to Enzyme proper.

https://github.com/EnzymeAD/Enzyme.jl/blob/6e867ba81bab2abafaed85f56a0f6e7cc38b01a2/src/rules/customrules.jl#L92-L98

danielwe commented 1 day ago

Tried making the argument to the outer function an integer (because integers are an inactive type) and computing the float rtol inside, but no dice. The following throws the same error:

using Enzyme, QuadGK

constantintegral(a, logrtol) = first(quadgk(_ -> a, -1.0, 1.0; rtol=exp(logrtol)))
@show autodiff(Reverse, constantintegral, Active, Active(1.0), Const(-23))

danielwe commented 1 day ago

Progress: The reason Enzyme takes the runtime path is that quadgk isn't specialized on the function argument (due to the notorious https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing). Adding a type parameter quadgk(f::F, ...; ...) where {F,...} at https://github.com/JuliaMath/QuadGK.jl/blob/master/src/api.jl#L80-L81 addresses the immediate issue. I'm sure it wouldn't be a problem to get that merged into QuadGK.
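
A minimal sketch of the pattern in question, using made-up stand-ins (trapezoid, integrate_passthrough, integrate_specialized) rather than QuadGK's actual code. It only illustrates the signature change; whether Julia really skips specialization in a given case depends on the heuristics described in the linked performance tip.

```julia
# Hypothetical stand-ins; this is not QuadGK's actual code.
trapezoid(f, a, b) = (f(a) + f(b)) / 2 * (b - a)

# `f` is only passed through to another function, so Julia may decide
# not to specialize this method on the concrete type of `f`.
integrate_passthrough(f, a, b; scale=1.0) = scale * trapezoid(f, a, b)

# The type parameter F forces a specialization for each function type.
integrate_specialized(f::F, a, b; scale=1.0) where {F} = scale * trapezoid(f, a, b)

g(x) = 3x + 1
println(integrate_passthrough(g, 0.0, 2.0; scale=0.5))
println(integrate_specialized(g, 0.0, 2.0; scale=0.5))
```

Both versions return the same value. The difference is only in which method instances the compiler generates, which is what seems to decide whether Enzyme sees a statically resolved call or falls back to the runtime handler.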

However:

  1. This seems brittle: Enzyme basically doesn't support custom rules for functions that take both function args and kwargs, unless they happen to be implemented in a way that forces specialization. The same is likely true for vararg + kwargs and type + kwargs. Is there a way for Enzyme to force specialization of every method for functions that have custom rules?
  2. The MWE still fails, now with an error that looks related to compiling the reverse rule. Stacktrace:
    
    ERROR: Enzyme compilation failed.
    [big info dump, see below]

Stacktrace:
  [1] multiple call sites
    @ unknown:0

Stacktrace:
  [1] (::Enzyme.Compiler.var"#getparent#18860"{…})(v::LLVM.Argument, offset::LLVM.ConstantInt, hasload::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:833
  [2] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:836
  [3] optimize!
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:2143 [inlined]
  [4] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:1889
  [5] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, f::Function, tt::Type, world::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:1827
  [6] enzyme_custom_common_rev
    @ ~/.julia/packages/Enzyme/uXW2v/src/rules/customrules.jl:753 [inlined]
  [7] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::LLVM.ExtractValueInst)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/customrules.jl:1138
  [8] enzyme_custom_rev_cfunc(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, tape::Ptr{LLVM.API.LLVMOpaqueValue})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/llvmrules.jl:27
  [9] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/uXW2v/src/api.jl:163
 [10] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:4004
 [11] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6308
 [12] codegen
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:5465 [inlined]
 [13] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7110
 [14] _thunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7110 [inlined]
 [15] cached_compilation
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7151 [inlined]
 [16] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7224
 [17] #s2084#19056
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7266 [inlined]
 [18]
    @ Enzyme.Compiler ./none:0
 [19] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [20] autodiff
    @ ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:311 [inlined]
 [21] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(constantintegral), ::Type{Active}, ::Active{Float64}, ::Const{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:328
 [22] macro expansion
    @ show.jl:1181 [inlined]
 [23] top-level scope
    @ ~/issues/quadgkkwargs.jl:5
 [24] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [25] top-level scope
    @ REPL[4]:1
in expression starting at /home/daniel/issues/quadgkkwargs.jl:5
Some type information was truncated. Use `show(err)` to see complete types.


<details><summary>Dumped info</summary>
<p>

Current scope: define internal fastcc [1 x [1 x double]] @julia_108865(double %0, [1 x [1 x double]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %1) unnamed_addr #0 !dbg !113 { top: %2 = call {}* @julia.get_pgcstack() %current_task189 = getelementptr inbounds {}*, {} %2, i64 -14 %current_task1 = bitcast {} %current_task189 to {} %ptls_field90 = getelementptr inbounds {}, {} %2, i64 2 %3 = bitcast {} %ptls_field90 to i64 %ptls_load9192 = load i64, i64** %3, align 8, !tbaa !117 %4 = getelementptr inbounds i64, i64 %ptls_load9192, i64 2 %safepoint = load i64*, i64 %4, align 8, !tbaa !121 fence syncscope("singlethread") seq_cst call void @julia.safepoint(i64 %safepoint), !dbg !123 fence syncscope("singlethread") seq_cst %5 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)), {} addrspace(10), {} addrspace(10), ...) @julia.call2({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)) noundef nonnull @ijl_invoke, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124614641495072 to {}) to {} addrspace(10)), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616757572208 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615702762608 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124616756638464 to {}) to {} addrspace(10))), !dbg !124 %6 = call {} addrspace(10) @ijl_get_nth_field_checked({} addrspace(10) nonnull %5, i64 noundef 0), !dbg !124 %7 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) 
@julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616783307456 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797712 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797216 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549799392 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %6), !dbg !135 %8 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10)*, i32) noundef nonnull @jl_fsvec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !137 %9 = call {} addrspace(10) @julia.typeof({} addrspace(10) nonnull %8) #114, !dbg !137 %10 = addrspacecast {} addrspace(10) %9 to {} addrspace(11), !dbg !137 %11 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %10) #114, !dbg !137 %exactly_isa.not = icmp eq {} %11, inttoptr (i64 124616754844864 to {}), !dbg !137 br i1 %exactly_isa.not, label %post_box_union, label %L13, !dbg !137

L13: ; preds = %top %magicptr = ptrtoint {}* %11 to i64, !dbg !137 switch i64 %magicptr, label %L26 [ i64 124616826846176, label %L15 i64 124615509956096, label %is48 ], !dbg !137

L15: ; preds = %L13 %12 = call fastcc i8 @julia____108839({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %8), !dbg !137, !range !138 br label %pass, !dbg !137

L26: ; preds = %L13, %pass %13 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !137 %14 = call {} addrspace(10) @julia.typeof({} addrspace(10) nonnull %13) #114, !dbg !137 %15 = addrspacecast {} addrspace(10) %14 to {} addrspace(11), !dbg !137 %16 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %15) #114, !dbg !137 %exactly_isa25.not = icmp eq {} %16, inttoptr (i64 124616754844864 to {}*), !dbg !137 br i1 %exactly_isa25.not, label %post_box_union32, label %L31, !dbg !137

L31: ; preds = %L26 %magicptr96 = ptrtoint {}* %16 to i64, !dbg !137 switch i64 %magicptr96, label %L44 [ i64 124616826846176, label %L33 i64 124615509956096, label %is ], !dbg !137

L33: ; preds = %L31 %17 = call fastcc i8 @julia____108839({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %13), !dbg !137, !range !138 br label %pass37, !dbg !137

L44: ; preds = %L31, %pass37 %18 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !139 %19 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616755884736 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %18), !dbg !139 %20 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616783307456 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549798176 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797216 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %19), !dbg !140 %21 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %20, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069024 to {}) to {} addrspace(10)*)), !dbg !140 br label %L51, !dbg !139

L51: ; preds = %pass, %pass37, %L44 %value_phi = phi {} addrspace(10) [ %21, %L44 ], [ addrspacecast ({} inttoptr (i64 124615834498672 to {}) to {} addrspace(10)), %pass37 ], [ addrspacecast ({} inttoptr (i64 124615834498672 to {}) to {} addrspace(10)), %pass ] %22 = call {} addrspace(10) @julia.typeof({} addrspace(10) %value_phi) #114, !dbg !141 %23 = addrspacecast {} addrspace(10) %22 to {} addrspace(11), !dbg !141 %24 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %23) #114, !dbg !141 %exactly_isa6.not = icmp eq {} %24, inttoptr (i64 124615509956096 to {}*), !dbg !141 br i1 %exactly_isa6.not, label %L68, label %L59, !dbg !141

L59: ; preds = %L51 %25 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616754836768 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615509956096 to {}) to {} addrspace(10)), {} addrspace(10) %value_phi), !dbg !141 br label %L68, !dbg !141

L68: ; preds = %L51, %L59 %unbox20.in.in = phi {} addrspace(10) [ %25, %L59 ], [ %value_phi, %L51 ] %unbox20.in = bitcast {} addrspace(10) %unbox20.in.in to i32 addrspace(10) %unbox20 = load i32, i32 addrspace(10) %unbox20.in, align 4, !dbg !142, !tbaa !146, !alias.scope !149, !noalias !152 %26 = and i32 %unbox20, -3, !dbg !145 %27 = icmp eq i32 %26, 0, !dbg !145 br i1 %27, label %L81, label %L76, !dbg !132

L76: ; preds = %L68 %memcpy_refined_src = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11) %1, i64 0, i64 0, i64 0, !dbg !157 %28 = load double, double addrspace(11) %memcpy_refined_src, align 8, !dbg !157, !tbaa !121, !alias.scope !158, !noalias !159 %box = call noalias nonnull dereferenceable(24) {} addrspace(10)* @julia.gc_alloc_obj({} nonnull %current_task1, i64 noundef 24, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124615414846864 to {}) to {} addrspace(10))) #115, !dbg !157 %29 = bitcast {} addrspace(10) %box to i8 addrspace(10), !dbg !157 %newstruct11.sroa.0.0..sroa_cast = bitcast {} addrspace(10) %box to double addrspace(10), !dbg !157 store double %28, double addrspace(10) %newstruct11.sroa.0.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %newstruct11.sroa.2.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10) %29, i64 8, !dbg !157 %newstruct11.sroa.2.0..sroa_cast = bitcast i8 addrspace(10) %newstruct11.sroa.2.0..sroa_idx to double addrspace(10), !dbg !157 store double %28, double addrspace(10) %newstruct11.sroa.2.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %newstruct11.sroa.3.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10) %29, i64 16, !dbg !157 %newstruct11.sroa.3.0..sroa_cast = bitcast i8 addrspace(10) %newstruct11.sroa.3.0..sroa_idx to double addrspace(10), !dbg !157 store double %0, double addrspace(10) %newstruct11.sroa.3.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %30 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)), {} addrspace(10), {} addrspace(10), ...) 
@julia.call2({} addrspace(10) ({} addrspace(10), {} addrspace(10)*, i32, {} addrspace(10)) noundef nonnull @ijl_invoke, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124615444310192 to {}) to {} addrspace(10)), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616757572208 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %box, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616756638464 to {}) to {} addrspace(10))), !dbg !157 %31 = addrspacecast {} addrspace(10) %30 to [1 x [1 x double]] addrspace(11), !dbg !165 %32 = bitcast {} addrspace(10) %30 to [1 x [1 x double]] addrspace(10) br label %L81

L81: ; preds = %L68, %L76 %nodecayed..pn = phi [1 x [1 x double]] addrspace(10) %nodecayedoff..pn = phi i64 %.pn = phi [1 x [1 x double]] addrspace(11) [ %31, %L76 ], [ %1, %L68 ] %.sroa.084.0.in = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11) %.pn, i64 0, i64 0, i64 0 %.sroa.084.0 = load double, double addrspace(11) %.sroa.084.0.in, align 8, !tbaa !160, !alias.scope !169, !noalias !170 %unbox10.fca.0.0.insert = insertvalue [1 x [1 x double]] poison, double %.sroa.084.0, 0, 0, !dbg !134 ret [1 x [1 x double]] %unbox10.fca.0.0.insert, !dbg !134

post_box_union: ; preds = %top call void @ijl_type_error(i8 noundef getelementptr inbounds ([3 x i8], [3 x i8] @_j_str4, i32 0, i32 0), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616826846592 to {}) to {} addrspace(10)), {} addrspace(12) noundef addrspacecast ({} inttoptr (i64 124616754844928 to {}) to {} addrspace(12))) #116, !dbg !137 unreachable, !dbg !137

pass: ; preds = %is48, %L15 %unbox4.ph = phi i8 [ %12, %L15 ], [ %phi.cast97, %is48 ] %.not93 = icmp eq i8 %unbox4.ph, 0, !dbg !137 br i1 %.not93, label %L26, label %L51, !dbg !137

post_box_union32: ; preds = %L26 call void @ijl_type_error(i8 noundef getelementptr inbounds ([3 x i8], [3 x i8] @_j_str4, i32 0, i32 0), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616826846592 to {}) to {} addrspace(10)), {} addrspace(12) noundef addrspacecast ({} inttoptr (i64 124616754844928 to {}) to {} addrspace(12))) #116, !dbg !137 unreachable, !dbg !137

pass37: ; preds = %is, %L33 %unbox38.ph = phi i8 [ %phi.cast, %is ], [ %17, %L33 ] %.not95 = icmp eq i8 %unbox38.ph, 0, !dbg !137 br i1 %.not95, label %L44, label %L51, !dbg !137

is: ; preds = %L31 %33 = bitcast {} addrspace(10) %13 to i32 addrspace(10), !dbg !171 %unbox43 = load i32, i32 addrspace(10)* %33, align 4, !dbg !171, !tbaa !146, !alias.scope !149, !noalias !152 %34 = icmp eq i32 %unbox43, 0, !dbg !171 %phi.cast = zext i1 %34 to i8, !dbg !171 br label %pass37, !dbg !171

is48: ; preds = %L13 %35 = bitcast {} addrspace(10) %8 to i32 addrspace(10), !dbg !171 %unbox50 = load i32, i32 addrspace(10)* %35, align 4, !dbg !171, !tbaa !146, !alias.scope !149, !noalias !152 %36 = icmp eq i32 %unbox50, 0, !dbg !171 %phi.cast97 = zext i1 %36 to i8, !dbg !171 br label %pass, !dbg !171 }

Could not analyze garbage collection behavior of inst: %.pn = phi [1 x [1 x double]] addrspace(11) [ %31, %L76 ], [ %1, %L68 ] v0: [1 x [1 x double]] addrspace(11) %1 v: [1 x [1 x double]] addrspace(11)* %1 offset: i64 0 hasload: false



</p>
</details>

danielwe commented 1 day ago
  1. Is there a way for Enzyme to force specialization of every method for functions that have custom rules?

A quick look at the devdocs suggests something like this should be possible using the method's generator field: https://docs.julialang.org/en/v1/devdocs/ast/#ast-lowered-method. Seems like this would be beneficial even for cases that don't currently error, but needlessly end up in runtime handlers.

danielwe commented 23 hours ago

2. The MWE still fails, now with an error that looks related to compiling the reverse rule.

This is even true without the kwarg usage that triggered the original error. That is, with QuadGK.jl patched as described in the 2nd to last comment, even the simplest QuadGK call, like the example below, fails with the same error:

using Enzyme, QuadGK

constantintegral(a) = first(quadgk(_ -> a, -1.0, 1.0))
@show autodiff(Reverse, constantintegral, Active, Active(1.0))

danielwe commented 18 hours ago

OK, here's one way to get around the second error: add @inline to the call to do_quadgk. So to get around both issues, apply the following diff to QuadGK.jl:

diff --git a/src/api.jl b/src/api.jl
index b375260..875d3c2 100644
--- a/src/api.jl
+++ b/src/api.jl
@@ -77,10 +77,10 @@ derivatives of the approximate integral).
 quadgk(f, segs...; kws...) =
     quadgk(f, promote(segs...)...; kws...)

-function quadgk(f, segs::T...;
-       atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {T}
+function quadgk(f::F, segs::T...;
+       atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {F,T}
     handle_infinities(f, segs) do f, s, _
-        do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
+        @inline do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
     end
 end

The @inline macro can also be placed at the call to handle_infinities; either works.

So to reiterate:

Happy to have found a workaround, but it's a bit unfortunate that it requires patching a well-tested and type-stable implementation that isn't doing anything particularly exotic. Any hope of Enzyme handling these cases in the future?

danielwe commented 3 hours ago

What also works is to add @noinline on quadgk itself, as in the diff below. So maybe the important thing is that quadgk does not get inlined, and the reason @inline do_quadgk also works is that it makes quadgk big enough to avoid automatic inlining.

diff --git a/src/api.jl b/src/api.jl
index b375260..17de777 100644
--- a/src/api.jl
+++ b/src/api.jl
@@ -77,8 +77,8 @@ derivatives of the approximate integral).
 quadgk(f, segs...; kws...) =
     quadgk(f, promote(segs...)...; kws...)

-function quadgk(f, segs::T...;
-       atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {T}
+@noinline function quadgk(f::F, segs::T...;
+       atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {F,T}
     handle_infinities(f, segs) do f, s, _
         do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
     end
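
To make the two annotations from these diffs concrete, here is a small sketch of how a declaration-site @noinline and a call-site @inline interact. The functions work, default_call, inlined_call, and has_invoke are hypothetical and unrelated to QuadGK's internals. Call-site @inline needs Julia 1.8 or later, and inlining is ultimately up to the compiler, so this is an illustration under those assumptions and not a statement about what happens inside quadgk.

```julia
# Hypothetical toy functions, unrelated to QuadGK's internals.
@noinline work(x) = x * x + 1.0      # declaration-site request: do not inline

default_call(x) = work(x)            # respects the @noinline on `work`
inlined_call(x) = @inline work(x)    # call-site @inline takes precedence (Julia 1.8+)

# True if the optimized IR of f(::Float64) still contains a non-inlined,
# statically dispatched call (an :invoke expression).
function has_invoke(f)
    ci = first(only(code_typed(f, (Float64,))))
    return any(stmt -> stmt isa Expr && stmt.head === :invoke, ci.code)
end

println("default_call keeps the call: ", has_invoke(default_call))
println("inlined_call keeps the call: ", has_invoke(inlined_call))
```

The same kind of check on polyintegral or constantintegral can show whether the quadgk call survives as a separate call after optimization or has been absorbed into the caller.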