JuliaMath / QuadGK.jl

adaptive 1d numerical Gauss–Kronrod integration in Julia
MIT License

Enzyme ext does not support Duplicated/MixedDuplicated closure as integrand #119

Open danielwe opened 1 month ago

danielwe commented 1 month ago

The Enzyme extension only implements reverse rules for integrands whose activity is Union{Const,Active}. In my problem, the integrand wraps an ApproxFun.jl Chebyshev interpolant specified by a Const function space and Duplicated coefficients. The closure therefore ends up with MixedDuplicated activity, and Enzyme throws an error when it tries to differentiate the QuadGK call. Support for this would be hugely appreciated!

cc @wsmoses

Brb with an MWE/test case

danielwe commented 1 month ago

MWE:

using ApproxFun, Enzyme, QuadGK

function chebyshevintegral(domain, coeffs)
    f = Fun(Chebyshev(domain), coeffs)
    return first(quadgk(f, endpoints(domain)...))
end

coeffs = [1.0]
domain = -1.0..1.0
@show chebyshevintegral(domain, coeffs)

dcoeffs = make_zero(coeffs)
autodiff(Reverse, chebyshevintegral, Active, Const(domain), Duplicated(coeffs, dcoeffs))
@show dcoeffs

Output:

julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
ERROR: LoadError: MethodError: no method matching reverse(::EnzymeCore.EnzymeRules.ConfigWidth{…}, ::Const{…}, ::Active{…}, ::Tuple{…}, ::MixedDuplicated{…}, ::Active{…}, ::Active{…})

Closest candidates are:
  reverse(::Any, ::Const{typeof(quadgk)}, ::Active, ::Any, ::Union{Active, Const}, ::Annotation{T}...; kws...) where T
   @ QuadGKEnzymeExt ~/.julia/packages/QuadGK/rjZYB/ext/QuadGKEnzymeExt.jl:87
  reverse(::Any, ::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, ::Any, ::BodyTy, ::Any, ::Annotation...) where {BodyTy, N}
   @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/internal_rules.jl:296
  reverse(::Any, ::Const{Type{BigFloat}}, ::Type{<:Union{BatchDuplicated, BatchDuplicatedNoNeed, Duplicated, DuplicatedNoNeed}}, ::Any, ::Any...)
   @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/internal_rules.jl:977
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
  [3] AdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6677 [inlined]
  [4] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(quadgk), df::Nothing, primal_1::Fun{…}, shadow_1_1::Base.RefValue{…}, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:468
  [5] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:5 [inlined]
  [6] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:0 [inlined]
  [7] diffejulia_chebyshevintegral_9069_inner_1wrap
    @ ~/issues/quadgkmixed.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [10] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
 [11] autodiff
    @ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
 [12] autodiff(::ReverseMode{…}, ::typeof(chebyshevintegral), ::Type{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
 [13] top-level scope
    @ ~/issues/quadgkmixed.jl:13
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [15] top-level scope
    @ REPL[3]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:13
Some type information was truncated. Use `show(err)` to see complete types.

For the record, Enzyme does support differentiating through the ApproxFun.jl evaluation, as you can confirm by changing the return to something like f(0.0) instead of the quadgk call.
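The check described above can be sketched as follows. It reuses the setup from the MWE and only swaps the return value for a point evaluation; the function name `chebyshevvalue` is made up for illustration.

```julia
using ApproxFun, Enzyme

# Same construction as in the MWE, but return a point evaluation of the
# interpolant instead of the quadgk integral.
# `chebyshevvalue` is a hypothetical name, not part of any package.
function chebyshevvalue(domain, coeffs)
    f = Fun(Chebyshev(domain), coeffs)
    return f(0.0)
end

coeffs = [1.0]
domain = -1.0..1.0
@show chebyshevvalue(domain, coeffs)

# Reverse-mode derivative with respect to the coefficients only.
dcoeffs = make_zero(coeffs)
autodiff(Reverse, chebyshevvalue, Active, Const(domain), Duplicated(coeffs, dcoeffs))
@show dcoeffs
```

The activities passed to `autodiff` are the same as in the failing MWE, so any difference in behavior comes from the `quadgk` call rather than from the ApproxFun.jl evaluation.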

stevengj commented 1 month ago

As a workaround, can you wrap it in a closure x -> f(x)?

stevengj commented 1 month ago

cc @wsmoses

danielwe commented 1 month ago

> can you wrap it in a closure x -> f(x)?

That doesn't help, unfortunately; you still end up with mixed activity. My real use case has a closure wrapping another function around the interpolant; calling the Fun object directly is just a consequence of the maximally stripped-down example.

danielwe commented 1 month ago

I'm trying to produce a workaround that does a more barebones Chebyshev polynomial evaluation using only the coefficients, with no function space/domain object. But even if I get rid of MixedDuplicated, the closure will still be Duplicated because the coefficients are a mutable vector, and that's not yet supported either. (Editing the title to reflect this.)
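For reference, a coefficients-only evaluation of the kind described above could look like the following sketch. It uses the Clenshaw recurrence on the reference interval [-1, 1]. The name `chebeval` and the coefficient ordering (`coeffs[k]` multiplies the Chebyshev polynomial of degree `k - 1`) are assumptions for illustration, not code from ApproxFun.jl or from the actual workaround.

```julia
# Evaluate sum over k of coeffs[k] * T_{k-1}(x) with the Clenshaw recurrence,
# for x in [-1, 1]. `chebeval` is a hypothetical helper; it assumes a
# 1-based coefficient vector.
function chebeval(coeffs::AbstractVector, x::Number)
    T = promote_type(eltype(coeffs), typeof(x))
    isempty(coeffs) && return zero(T)
    b1 = zero(T)  # b_{k+1} in the recurrence
    b2 = zero(T)  # b_{k+2} in the recurrence
    for k in length(coeffs):-1:2
        b1, b2 = coeffs[k] + 2x * b1 - b2, b1
    end
    return coeffs[1] + x * b1 - b2
end

# T_2(0.5) = 2 * 0.5^2 - 1 = -0.5
@show chebeval([0.0, 0.0, 1.0], 0.5)
```

A closure such as `x -> chebeval(coeffs, x)` still captures the mutable coefficient vector, so it has Duplicated activity, as noted above.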

stevengj commented 1 month ago

The Duplicated (and MixedDuplicated) cases are blocked by https://github.com/EnzymeAD/Enzyme.jl/issues/1692, as I understand it: https://github.com/JuliaMath/QuadGK.jl/blob/ce727e15f76df016ee2db9819fab0b4a7c6117fe/test/runtests.jl#L475-L476

danielwe commented 1 month ago

My case only has Duplicated/MixedDuplicated in the args, not the return. Unless there's some specific issue with the ClosureVector construction, I think it should be possible to get Duplicated arg/Active return working with the current Enzyme; that's the standard reverse mode configuration used everywhere. Not sure about MixedDuplicated arg/Active return; MixedDuplicated clearly exists but its usage is undocumented, so only @wsmoses would know.

I'll take a stab at the Duplicated arg/Active return case today

wsmoses commented 1 month ago

@danielwe untested, but does something like https://github.com/JuliaMath/QuadGK.jl/pull/120 work?

danielwe commented 1 month ago

Here's an MWE that doesn't use ApproxFun, to rule out interference from things in that package that would need their own custom rules:

using Enzyme, QuadGK

function polyintegral(coeffs, scale)
    f(x) = scale * evalpoly(x, coeffs)
    return first(quadgk(f, -1.0, 1.0))
end

coeffs = [1.0]
scale = 1.0
@show polyintegral(coeffs, scale)

dcoeffs = make_zero(coeffs)
autodiff(Reverse, polyintegral, Active, Duplicated(coeffs, dcoeffs), Const(scale))
@show dcoeffs

Removing all references to scale in that code gives an MWE for the related case of a Duplicated rather than MixedDuplicated closure.
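Spelled out, that Duplicated-only variant might look like the sketch below. The `try`/`catch` wrapper is an addition for illustration, so the script reports the outcome instead of aborting; whether the `autodiff` call succeeds depends on the installed QuadGK and Enzyme versions.

```julia
using Enzyme, QuadGK

# Same as the MWE above with `scale` removed, so the closure captures
# only the mutable coefficient vector (Duplicated activity).
function polyintegral(coeffs)
    f(x) = evalpoly(x, coeffs)
    return first(quadgk(f, -1.0, 1.0))
end

coeffs = [1.0]
@show polyintegral(coeffs)

dcoeffs = make_zero(coeffs)
try
    autodiff(Reverse, polyintegral, Active, Duplicated(coeffs, dcoeffs))
    @show dcoeffs
catch err
    # Expected when the Enzyme extension has no rule for Duplicated closures.
    @warn "autodiff through quadgk failed" exception = err
end
```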

EDIT: Can also use Active(scale).
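
For checking results once a rule is in place, the derivatives for this MWE follow from integrating the polynomial term by term. This is a worked derivation, not output from Enzyme; here $c_k$ stands for `coeffs[k+1]` and $s$ for `scale`.

```latex
\begin{aligned}
I(c, s) &= \int_{-1}^{1} s \sum_{k=0}^{n} c_k x^k \, dx
         = s \sum_{k=0}^{n} c_k \, \frac{1 - (-1)^{k+1}}{k+1}, \\
\frac{\partial I}{\partial c_k} &= s \, \frac{1 - (-1)^{k+1}}{k+1}
  = \begin{cases} \dfrac{2s}{k+1}, & k \text{ even}, \\ 0, & k \text{ odd}, \end{cases} \\
\frac{\partial I}{\partial s} &= \sum_{k=0}^{n} c_k \, \frac{1 - (-1)^{k+1}}{k+1}.
\end{aligned}
```

With coeffs = [1.0] and scale = 1.0, both derivatives equal 2, so a working rule should leave a zero-initialized dcoeffs equal to [2.0].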