Closed: danielwe closed this issue 2 months ago.
The code is type unstable, so activity analysis couldn't prove that the derivative of `config` isn't needed (see the `runtime_generic_augfwd`, which is our handler for a type-unstable call).
The differentiated function is fully inferred:
```julia
julia> @code_warntype polyintegral(coeffs, config)
MethodInstance for polyintegral(::Tuple{Float64}, ::@NamedTuple{rtol::Float64})
from polyintegral(coeffs, config) @ Main ~/issues/quadgkkwargs.jl:18
Arguments
#self#::Core.Const(polyintegral)
coeffs::Tuple{Float64}
config::@NamedTuple{rtol::Float64}
Locals
f::var"#f#1"{Tuple{Float64}}
@_5::Tuple{Float64, Float64}
Body::Float64
1 ─ %1 = Main.:(var"#f#1")::Core.Const(var"#f#1")
│ %2 = Core.typeof(coeffs)::Core.Const(Tuple{Float64})
│ %3 = Core.apply_type(%1, %2)::Core.Const(var"#f#1"{Tuple{Float64}})
│ (f = %new(%3, coeffs))
│ %5 = Base.NamedTuple()::Core.Const(NamedTuple())
│ %6 = Base.merge(%5, config)::@NamedTuple{rtol::Float64}
│ %7 = Base.isempty(%6)::Core.Const(false)
└── goto #3 if not %7
2 ─ Core.Const(:(@_5 = Main.quadgk(f, -1.0, 1.0)))
└── Core.Const(:(goto %12))
3 ┄ (@_5 = Core.kwcall(%6, Main.quadgk, f, -1.0, 1.0))
│ %12 = @_5::Tuple{Float64, Float64}
│ %13 = Main.first(%12)::Float64
└── return %13
```
Is the type instability introduced in the custom rule?
No, EnzymeRules don't add type instabilities.
Then I'm confused... Is the issue that the inferred types are lost on Enzyme somehow?
What happens if you use `code_typed`?
All blue:
```julia
julia> @code_typed polyintegral(coeffs, config)
CodeInfo(
1 ─ %1 = %new(var"#f#5"{Tuple{Float64}}, coeffs)::var"#f#5"{Tuple{Float64}}
│ %2 = Core.getfield(config, :rtol)::Float64
│ %3 = QuadGK.nothing::Nothing
│ %4 = QuadGK.norm::typeof(LinearAlgebra.norm)
│ %5 = QuadGK.nothing::Nothing
│ %6 = QuadGK.nothing::Nothing
│ %7 = %new(QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}, %3, %2, 10000000, 7, %4, %5, %6)::QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}
│ %8 = invoke QuadGK.handle_infinities(%7::QuadGK.var"#50#51"{Nothing, Float64, Int64, Int64, typeof(LinearAlgebra.norm), Nothing, Nothing}, %1::var"#f#5"{Tuple{Float64}}, (-1.0, 1.0)::Tuple{Float64, Float64})::Tuple{Float64, Float64}
│ %9 = Base.getfield(%8, 1, true)::Float64
└── return %9
) => Float64
```
Also, for good measure:
```julia
julia> Base.promote_op(polyintegral, Tuple{Float64}, @NamedTuple{rtol::Float64})
Float64
julia> Core.Compiler.return_type(polyintegral, Tuple{Tuple{Float64},@NamedTuple{rtol::Float64}})
Float64
```
So it looks like in `code_typed` there is no kwarg handling, because `quadgk` is inlined and the following do block is inserted directly in the body of `polyintegral`: https://github.com/JuliaMath/QuadGK.jl/blob/ce727e15f76df016ee2db9819fab0b4a7c6117fe/src/api.jl#L82-L84. What does Enzyme do in such cases, where the call that has a custom rule has been inlined? Does it have to intercept earlier in the compiler pipeline, before inlining, where types might not be fully known?
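The difference between the two outputs above can be reproduced without QuadGK. `@code_warntype` shows inferred code *before* optimization, where the keyword call is still an explicit `Core.kwcall`. `@code_typed` shows optimized code after inlining, where that layer is gone. The sketch below assumes Julia 1.9 or later, where keyword calls lower to `Core.kwcall`. `scaled` and `caller` are hypothetical stand-ins for `quadgk` and `polyintegral`. The expected result is that the first check finds the `Core.kwcall` and the second does not.

```julia
# Requires Julia >= 1.9, where keyword calls lower to `Core.kwcall`.
# `scaled` and `caller` are hypothetical stand-ins for `quadgk` and `polyintegral`.
scaled(x; factor=2.0) = factor * x
caller(x) = scaled(x; factor=3.0)

# Does any statement in `ci` call `Core.kwcall` directly?
function calls_kwcall(ci::Core.CodeInfo)
    return any(ci.code) do stmt
        # Unoptimized code may wrap the call in a slot assignment (`@_5 = Core.kwcall(...)`).
        if stmt isa Expr && stmt.head === :(=)
            stmt = stmt.args[2]
        end
        stmt isa Expr && stmt.head === :call || return false
        callee = stmt.args[1]
        return callee isa GlobalRef && callee.mod === Core && callee.name === :kwcall
    end
end

# Inferred but unoptimized code, as displayed by @code_warntype.
unopt = first(code_typed(caller, (Float64,); optimize=false)).first
# Optimized code after inlining, as displayed by @code_typed.
opt = first(code_typed(caller, (Float64,))).first

println("kwcall before optimization: ", calls_kwcall(unopt))
println("kwcall after optimization:  ", calls_kwcall(opt))
```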
Also, there's an `invoke` in front of the call to `handle_infinities`. I'm not sure what that means (there's no `@invoke` or `invokelatest` in the code) or whether it could matter.
Looking at this makes me wonder whether the inner function `do_quadgk`, which does not have kwargs and isn't inlined here, might be an even better target for the custom rule. It too supports returning the segbuf by using the `ReturnSegbuf` wrapper.
Re the above musings: Adding `@noinline` to `quadgk` (line 80 following the link in the previous post) changes `@code_typed` the way you would expect, but doesn't change the error in the MWE at all. And it looks like `invoke` is just SSA speak for a statically dispatched call to a `MethodInstance`, which is obviously nothing to worry about. Clearly, I don't have much to contribute here.
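The following minimal Base-only sketch uses hypothetical functions unrelated to QuadGK. A call that is resolved at compile time but not inlined appears in the optimized `CodeInfo` as an `Expr` with head `:invoke`, which `@code_typed` prints as `invoke`. Marking the callee `@noinline` keeps the call in that form. A small or `@inline` callee is absorbed into the caller instead.

```julia
# Hypothetical example functions; unrelated to QuadGK.
@noinline plusone_noinline(x) = x + 1
@inline plusone_inline(x) = x + 1

caller_noinline(x) = 2.0 * plusone_noinline(x)
caller_inline(x) = 2.0 * plusone_inline(x)

# True if the optimized code for `f(::types...)` contains a statically dispatched,
# non-inlined call, which @code_typed prints with the `invoke` prefix.
function has_invoke(f, types)
    ci = first(code_typed(f, types)).first
    return any(stmt -> stmt isa Expr && stmt.head === :invoke, ci.code)
end

println(has_invoke(caller_noinline, (Float64,)))  # call kept as `invoke plusone_noinline(...)`
println(has_invoke(caller_inline, (Float64,)))    # body inlined, so no `invoke`
```

Both callers compute the same value. Only the shape of the optimized code differs.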
I'm at a loss trying to debug this further. The MWE can be minimized to just this:
```julia
using Enzyme, QuadGK
constantintegral(a, rtol) = first(quadgk(_ -> a, -1.0, 1.0; rtol))
@show autodiff(Reverse, constantintegral, Active, Active(1.0), Const(1e-10))
```
It's not a general issue with float kwargs and custom rules: I confirmed that by adding a float kwarg to the example from the custom rule tutorial, and that worked fine.
I assume the call should never have gone to `runtime_generic_augfwd`, since the function is stable and inferred, but I can't figure out where and how Enzyme decides to go that route. In any case, the fate of erroring is ultimately sealed in the following lines, but beyond that, the trail goes cold (excuse me, ccalled), as `activep` is set by a `ccall` into Enzyme proper.
I tried making the argument to the outer function an integer (because that's an inactive type) and computing the float `rtol` inside, but no dice; the following throws the same error:
```julia
using Enzyme, QuadGK
constantintegral(a, logrtol) = first(quadgk(_ -> a, -1.0, 1.0; rtol=exp(logrtol)))
@show autodiff(Reverse, constantintegral, Active, Active(1.0), Const(-23))
```
Progress: The reason Enzyme takes the runtime path is that `quadgk` isn't specialized on the function argument (due to the notorious https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing). Adding a type parameter, `quadgk(f::F, ...; ...) where {F,...}`, at https://github.com/JuliaMath/QuadGK.jl/blob/master/src/api.jl#L80-L81 addresses the immediate issue. I'm sure it wouldn't be a problem to get that merged into QuadGK.
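The following minimal Base-only sketch illustrates the heuristic from the linked performance tip; the function names are hypothetical. When a `Function` argument is only passed through to another function rather than called in the method body, Julia may compile one version of the method for `::Function` instead of one per concrete function type. A type parameter such as `g::G ... where {G}` forces specialization. The same docs note that `@code_typed` and friends always show specialized code, even when Julia would not normally specialize the call. That may be why the output earlier in the thread looked fully specialized.

```julia
# Hypothetical example of the specialization heuristic described in the Julia manual
# (performance tips: "Be aware of when Julia avoids specializing").

# `g` is called in the body, so Julia specializes on its concrete type.
apply_twice(g, x) = g(g(x))

# `g` is only passed through to another function, so Julia may compile a single
# version of this method for `g::Function` (the heuristic that affected `quadgk`).
passthrough(g, x) = apply_twice(g, x)

# A type parameter on the argument forces specialization on `typeof(g)`,
# the same change as `quadgk(f::F, ...) where {F,...}`.
forced(g::G, x) where {G} = apply_twice(g, x)

# Both compute sqrt(sqrt(16.0)); only the compiled specializations differ.
println(passthrough(sqrt, 16.0))
println(forced(sqrt, 16.0))
```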
However:
```
ERROR: Enzyme compilation failed.
[big info dump, see below]
Stacktrace: [1] multiple call sites @ unknown:0
Stacktrace:
[1] (::Enzyme.Compiler.var"#getparent#18860"{…})(v::LLVM.Argument, offset::LLVM.ConstantInt, hasload::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:833
[2] nodecayed_phis!(mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:836
[3] optimize!
@ ~/.julia/packages/Enzyme/uXW2v/src/compiler/optimize.jl:2143 [inlined]
[4] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:1889
[5] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, f::Function, tt::Type, world::UInt64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:1827
[6] enzyme_custom_common_rev
@ ~/.julia/packages/Enzyme/uXW2v/src/rules/customrules.jl:753 [inlined]
[7] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::LLVM.ExtractValueInst)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/customrules.jl:1138
[8] enzyme_custom_rev_cfunc(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, tape::Ptr{LLVM.API.LLVMOpaqueValue})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/llvmrules.jl:27
[9] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/uXW2v/src/api.jl:163
[10] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:4004
[11] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6308
[12] codegen
@ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:5465 [inlined]
[13] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7110
[14] _thunk
@ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7110 [inlined]
[15] cached_compilation
@ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7151 [inlined]
[16] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7224
[17] #s2084#19056
@ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7266 [inlined]
[18]
@ Enzyme.Compiler ./none:0
[19] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[20] autodiff
@ ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:311 [inlined]
[21] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(constantintegral), ::Type{Active}, ::Active{Float64}, ::Const{Float64})
@ Enzyme ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:328
[22] macro expansion
@ show.jl:1181 [inlined]
[23] top-level scope
@ ~/issues/quadgkkwargs.jl:5
[24] include(fname::String)
@ Base.MainInclude ./client.jl:489
[25] top-level scope
@ REPL[4]:1
in expression starting at /home/daniel/issues/quadgkkwargs.jl:5
Some type information was truncated. Use `show(err)` to see complete types.
```
<details><summary>Dumped info</summary>
<p>
Current scope: define internal fastcc [1 x [1 x double]] @julia_108865(double %0, [1 x [1 x double]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %1) unnamed_addr #0 !dbg !113 { top: %2 = call {}* @julia.get_pgcstack() %current_task189 = getelementptr inbounds {}*, {} %2, i64 -14 %current_task1 = bitcast {} %current_task189 to {} %ptls_field90 = getelementptr inbounds {}, {} %2, i64 2 %3 = bitcast {} %ptls_field90 to i64 %ptls_load9192 = load i64, i64** %3, align 8, !tbaa !117 %4 = getelementptr inbounds i64, i64 %ptls_load9192, i64 2 %safepoint = load i64*, i64 %4, align 8, !tbaa !121 fence syncscope("singlethread") seq_cst call void @julia.safepoint(i64 %safepoint), !dbg !123 fence syncscope("singlethread") seq_cst %5 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)), {} addrspace(10), {} addrspace(10), ...) @julia.call2({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)) noundef nonnull @ijl_invoke, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124614641495072 to {}) to {} addrspace(10)), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616757572208 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615702762608 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124616756638464 to {}) to {} addrspace(10))), !dbg !124 %6 = call {} addrspace(10) @ijl_get_nth_field_checked({} addrspace(10) nonnull %5, i64 noundef 0), !dbg !124 %7 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) 
@julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616783307456 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797712 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797216 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549799392 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %6), !dbg !135 %8 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10)*, i32) noundef nonnull @jl_fsvec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !137 %9 = call {} addrspace(10) @julia.typeof({} addrspace(10) nonnull %8) #114, !dbg !137 %10 = addrspacecast {} addrspace(10) %9 to {} addrspace(11), !dbg !137 %11 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %10) #114, !dbg !137 %exactly_isa.not = icmp eq {} %11, inttoptr (i64 124616754844864 to {}), !dbg !137 br i1 %exactly_isa.not, label %post_box_union, label %L13, !dbg !137
L13: ; preds = %top %magicptr = ptrtoint {}* %11 to i64, !dbg !137 switch i64 %magicptr, label %L26 [ i64 124616826846176, label %L15 i64 124615509956096, label %is48 ], !dbg !137
L15: ; preds = %L13 %12 = call fastcc i8 @julia____108839({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %8), !dbg !137, !range !138 br label %pass, !dbg !137
L26: ; preds = %L13, %pass %13 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !137 %14 = call {} addrspace(10) @julia.typeof({} addrspace(10) nonnull %13) #114, !dbg !137 %15 = addrspacecast {} addrspace(10) %14 to {} addrspace(11), !dbg !137 %16 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %15) #114, !dbg !137 %exactly_isa25.not = icmp eq {} %16, inttoptr (i64 124616754844864 to {}*), !dbg !137 br i1 %exactly_isa25.not, label %post_box_union32, label %L31, !dbg !137
L31: ; preds = %L26 %magicptr96 = ptrtoint {}* %16 to i64, !dbg !137 switch i64 %magicptr96, label %L44 [ i64 124616826846176, label %L33 i64 124615509956096, label %is ], !dbg !137
L33: ; preds = %L31 %17 = call fastcc i8 @julia____108839({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %13), !dbg !137, !range !138 br label %pass37, !dbg !137
L44: ; preds = %L31, %pass37 %18 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %7, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069312 to {}) to {} addrspace(10))), !dbg !139 %19 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616755884736 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %18), !dbg !139 %20 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616783307456 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549798176 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615549797216 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %19), !dbg !140 %21 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @jl_f__svec_ref, {} addrspace(10) noundef null, {} addrspace(10) nonnull %20, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616983069024 to {}) to {} addrspace(10)*)), !dbg !140 br label %L51, !dbg !139
L51: ; preds = %pass, %pass37, %L44 %value_phi = phi {} addrspace(10) [ %21, %L44 ], [ addrspacecast ({} inttoptr (i64 124615834498672 to {}) to {} addrspace(10)), %pass37 ], [ addrspacecast ({} inttoptr (i64 124615834498672 to {}) to {} addrspace(10)), %pass ] %22 = call {} addrspace(10) @julia.typeof({} addrspace(10) %value_phi) #114, !dbg !141 %23 = addrspacecast {} addrspace(10) %22 to {} addrspace(11), !dbg !141 %24 = call nonnull {} @julia.pointer_from_objref({} addrspace(11) %23) #114, !dbg !141 %exactly_isa6.not = icmp eq {} %24, inttoptr (i64 124615509956096 to {}*), !dbg !141 br i1 %exactly_isa6.not, label %L68, label %L59, !dbg !141
L59: ; preds = %L51 %25 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10)*, {} addrspace(10), i32), {} addrspace(10), ...) @julia.call({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32) noundef nonnull @ijl_apply_generic, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616754836768 to {}) to {} addrspace(10)), {} addrspace(10) addrspacecast ({} inttoptr (i64 124615509956096 to {}) to {} addrspace(10)), {} addrspace(10) %value_phi), !dbg !141 br label %L68, !dbg !141
L68: ; preds = %L51, %L59 %unbox20.in.in = phi {} addrspace(10) [ %25, %L59 ], [ %value_phi, %L51 ] %unbox20.in = bitcast {} addrspace(10) %unbox20.in.in to i32 addrspace(10) %unbox20 = load i32, i32 addrspace(10) %unbox20.in, align 4, !dbg !142, !tbaa !146, !alias.scope !149, !noalias !152 %26 = and i32 %unbox20, -3, !dbg !145 %27 = icmp eq i32 %26, 0, !dbg !145 br i1 %27, label %L81, label %L76, !dbg !132
L76: ; preds = %L68 %memcpy_refined_src = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11) %1, i64 0, i64 0, i64 0, !dbg !157 %28 = load double, double addrspace(11) %memcpy_refined_src, align 8, !dbg !157, !tbaa !121, !alias.scope !158, !noalias !159 %box = call noalias nonnull dereferenceable(24) {} addrspace(10)* @julia.gc_alloc_obj({} nonnull %current_task1, i64 noundef 24, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124615414846864 to {}) to {} addrspace(10))) #115, !dbg !157 %29 = bitcast {} addrspace(10) %box to i8 addrspace(10), !dbg !157 %newstruct11.sroa.0.0..sroa_cast = bitcast {} addrspace(10) %box to double addrspace(10), !dbg !157 store double %28, double addrspace(10) %newstruct11.sroa.0.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %newstruct11.sroa.2.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10) %29, i64 8, !dbg !157 %newstruct11.sroa.2.0..sroa_cast = bitcast i8 addrspace(10) %newstruct11.sroa.2.0..sroa_idx to double addrspace(10), !dbg !157 store double %28, double addrspace(10) %newstruct11.sroa.2.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %newstruct11.sroa.3.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10) %29, i64 16, !dbg !157 %newstruct11.sroa.3.0..sroa_cast = bitcast i8 addrspace(10) %newstruct11.sroa.3.0..sroa_idx to double addrspace(10), !dbg !157 store double %0, double addrspace(10) %newstruct11.sroa.3.0..sroa_cast, align 8, !dbg !157, !tbaa !160, !alias.scope !161, !noalias !162 %30 = call nonnull {} addrspace(10) ({} addrspace(10) ({} addrspace(10), {} addrspace(10), i32, {} addrspace(10)), {} addrspace(10), {} addrspace(10), ...) 
@julia.call2({} addrspace(10) ({} addrspace(10), {} addrspace(10)*, i32, {} addrspace(10)) noundef nonnull @ijl_invoke, {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124615444310192 to {}) to {} addrspace(10)), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616757572208 to {}) to {} addrspace(10)), {} addrspace(10) nonnull %box, {} addrspace(10) addrspacecast ({} inttoptr (i64 124616756638464 to {}) to {} addrspace(10))), !dbg !157 %31 = addrspacecast {} addrspace(10) %30 to [1 x [1 x double]] addrspace(11), !dbg !165 %32 = bitcast {} addrspace(10) %30 to [1 x [1 x double]] addrspace(10) br label %L81
L81: ; preds = %L68, %L76 %nodecayed..pn = phi [1 x [1 x double]] addrspace(10) %nodecayedoff..pn = phi i64 %.pn = phi [1 x [1 x double]] addrspace(11) [ %31, %L76 ], [ %1, %L68 ] %.sroa.084.0.in = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11) %.pn, i64 0, i64 0, i64 0 %.sroa.084.0 = load double, double addrspace(11) %.sroa.084.0.in, align 8, !tbaa !160, !alias.scope !169, !noalias !170 %unbox10.fca.0.0.insert = insertvalue [1 x [1 x double]] poison, double %.sroa.084.0, 0, 0, !dbg !134 ret [1 x [1 x double]] %unbox10.fca.0.0.insert, !dbg !134
post_box_union: ; preds = %top call void @ijl_type_error(i8 noundef getelementptr inbounds ([3 x i8], [3 x i8] @_j_str4, i32 0, i32 0), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616826846592 to {}) to {} addrspace(10)), {} addrspace(12) noundef addrspacecast ({} inttoptr (i64 124616754844928 to {}) to {} addrspace(12))) #116, !dbg !137 unreachable, !dbg !137
pass: ; preds = %is48, %L15 %unbox4.ph = phi i8 [ %12, %L15 ], [ %phi.cast97, %is48 ] %.not93 = icmp eq i8 %unbox4.ph, 0, !dbg !137 br i1 %.not93, label %L26, label %L51, !dbg !137
post_box_union32: ; preds = %L26 call void @ijl_type_error(i8 noundef getelementptr inbounds ([3 x i8], [3 x i8] @_j_str4, i32 0, i32 0), {} addrspace(10) noundef addrspacecast ({} inttoptr (i64 124616826846592 to {}) to {} addrspace(10)), {} addrspace(12) noundef addrspacecast ({} inttoptr (i64 124616754844928 to {}) to {} addrspace(12))) #116, !dbg !137 unreachable, !dbg !137
pass37: ; preds = %is, %L33 %unbox38.ph = phi i8 [ %phi.cast, %is ], [ %17, %L33 ] %.not95 = icmp eq i8 %unbox38.ph, 0, !dbg !137 br i1 %.not95, label %L44, label %L51, !dbg !137
is: ; preds = %L31 %33 = bitcast {} addrspace(10) %13 to i32 addrspace(10), !dbg !171 %unbox43 = load i32, i32 addrspace(10)* %33, align 4, !dbg !171, !tbaa !146, !alias.scope !149, !noalias !152 %34 = icmp eq i32 %unbox43, 0, !dbg !171 %phi.cast = zext i1 %34 to i8, !dbg !171 br label %pass37, !dbg !171
is48: ; preds = %L13 %35 = bitcast {} addrspace(10) %8 to i32 addrspace(10), !dbg !171 %unbox50 = load i32, i32 addrspace(10)* %35, align 4, !dbg !171, !tbaa !146, !alias.scope !149, !noalias !152 %36 = icmp eq i32 %unbox50, 0, !dbg !171 %phi.cast97 = zext i1 %36 to i8, !dbg !171 br label %pass, !dbg !171 }
Could not analyze garbage collection behavior of inst: %.pn = phi [1 x [1 x double]] addrspace(11) [ %31, %L76 ], [ %1, %L68 ] v0: [1 x [1 x double]] addrspace(11) %1 v: [1 x [1 x double]] addrspace(11)* %1 offset: i64 0 hasload: false
</p>
</details>
1. Is there a way for Enzyme to force specialization of every method for functions that have custom rules? A quick look at the devdocs suggests something like this should be possible using the method's `generator` field: https://docs.julialang.org/en/v1/devdocs/ast/#ast-lowered-method. It seems like this would be beneficial even for cases that don't currently error but needlessly end up in runtime handlers.
2. The MWE still fails, now with an error that looks related to compiling the reverse rule. This is true even without the kwarg usage that triggered the original error. That is, with QuadGK.jl patched as described in the second-to-last comment, even the simplest QuadGK call, like the example below, fails with the same error:
```julia
using Enzyme, QuadGK
constantintegral(a) = first(quadgk(_ -> a, -1.0, 1.0))
@show autodiff(Reverse, constantintegral, Active, Active(1.0))
```
OK, here's one way to get around the second error: add `@inline` to the call to `do_quadgk`. So to get around both issues, apply the following diff to QuadGK.jl:
```diff
diff --git a/src/api.jl b/src/api.jl
index b375260..875d3c2 100644
--- a/src/api.jl
+++ b/src/api.jl
@@ -77,10 +77,10 @@ derivatives of the approximate integral).
quadgk(f, segs...; kws...) =
quadgk(f, promote(segs...)...; kws...)
-function quadgk(f, segs::T...;
- atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {T}
+function quadgk(f::F, segs::T...;
+ atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {F,T}
handle_infinities(f, segs) do f, s, _
- do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
+ @inline do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
end
end
```
The `@inline` macro can also be placed at the call to `handle_infinities`; either works.
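The following Base-only mock shows where the `where {F}` parameter and the call-site `@inline` go in this structure. It is hypothetical and simplified, not QuadGK's actual code. `myquad`, `with_segments`, and `midpoint_integral` are made-up stand-ins for `quadgk`, `handle_infinities`, and `do_quadgk`, and a plain midpoint rule replaces Gauss-Kronrod. Call-site `@inline` requires Julia 1.8 or later. The sketch only illustrates the syntax and does not demonstrate anything about the Enzyme error itself.

```julia
# Hypothetical, simplified mock of the structure of QuadGK's `quadgk`; not its real code.
# Requires Julia >= 1.8 for call-site `@inline`.

# Stand-in for `handle_infinities`: forwards directly to the do-block closure.
with_segments(work, f, segs) = work(f, segs, nothing)

# Stand-in for `do_quadgk`: composite midpoint rule on a single segment (a, b).
function midpoint_integral(f, (a, b), n)
    h = (b - a) / n
    return h * sum(f(a + (i - 0.5) * h) for i in 1:n)
end

# Keyword front end with a type parameter on `f` to force specialization.
function myquad(f::F, segs::T...; n=1000) where {F,T}
    with_segments(f, segs) do f, s, _
        # Call-site @inline, analogous to `@inline do_quadgk(...)` in the patch.
        @inline midpoint_integral(f, s, n)
    end
end

println(myquad(_ -> 2.0, -1.0, 1.0))   # exact integral: 4
println(myquad(x -> x^2, -1.0, 1.0))   # exact integral: 2/3
```

Moving `@inline` to the `with_segments(...)` call instead would mirror the alternative placement mentioned above.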
So to reiterate:

1. Add `::F` to force specialization on the passed function; this fixes the original error. Reproducing that error without QuadGK doesn't take much (I suppose I haven't provided such a reproducer, but all you need is an inner function that takes both a function arg and a non-active but not unused float kwarg, and has a custom rule).
2. However, `::F` breaks every use of Enzyme, both those that worked before and those that didn't. The new error is shown 3 comments up the thread.
3. Fix that with `@inline` as shown in the above diff. Now `autodiff` works for `quadgk` with or without the kwarg.

Happy to have found a workaround, but it's a bit unfortunate that it requires patching a well-tested and type-stable implementation that isn't doing anything particularly exotic. Any hope of Enzyme handling these cases in the future?
What also works is to add `@noinline` on `quadgk` itself, as in the diff below. So maybe the important thing is that `quadgk` does not get inlined, and the reason `@inline do_quadgk` also works is that it makes `quadgk` big enough to avoid automatic inlining.
```diff
diff --git a/src/api.jl b/src/api.jl
index b375260..17de777 100644
--- a/src/api.jl
+++ b/src/api.jl
@@ -77,8 +77,8 @@ derivatives of the approximate integral).
quadgk(f, segs...; kws...) =
quadgk(f, promote(segs...)...; kws...)
-function quadgk(f, segs::T...;
- atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {T}
+@noinline function quadgk(f::F, segs::T...;
+ atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing, eval_segbuf=nothing) where {F,T}
handle_infinities(f, segs) do f, s, _
do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf, eval_segbuf)
end
```
Since this meandering thread revealed two distinct errors, I'll close this issue and open two new ones with sharper focus and more accurate titles.
MWE:
Output:
It's unclear to me whether this is specific to the implementation of QuadGK's custom rule or an issue with the custom rule machinery in general, but it looked to me like it might be the latter, so I'm filing here.