A common interface for quadrature and numerical integration for the SciML scientific machine learning organization
ForwardDiff differentiates through solver #251

Open itsdfish opened 3 months ago

itsdfish commented 3 months ago

Describe the bug 🐞

ForwardDiff.jl differentiates through the integral solver in the example below, where the integration bounds depend on the input Θ.

Expected behavior

In the example below, Integrals.jl should compute the gradient without ForwardDiff's dual numbers ever reaching the QuadGKJL algorithm. Instead, they reach QuadGK's kronrod function, which throws a MethodError. Note that this example works correctly with Zygote.

Minimal Reproducible Example 👇

using BenchmarkTools
using Distributions
using Integrals
using ForwardDiff

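# Density of Normal(p, x) evaluated at 1; x plays the role of the standard deviation
integrand(x::T1, p::T2) where {T1<:Real,T2<:Real} = pdf(Normal(p, x), one(promote_type(T1, T2)))

# Θ[1] is the parameter; Θ[2] and Θ[3] are the lower and upper integration bounds
function f(Θ)
    domain = (Θ[2], Θ[3])
    p = Θ[1]
    prob = IntegralProblem(integrand, domain, p)
    sol = solve(prob, QuadGKJL(); reltol=1e-3, abstol=1e-3)
    return sol.u
end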

Θ = [0.0, 0.0, 2.0]
ForwardDiff.gradient(f, Θ)[1]

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching kronrod(::Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 3}}, ::Int64)

Closest candidates are:
  kronrod(::Any, ::Integer, ::Real, ::Real; rtol, quad)
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/weightedgauss.jl:90
  kronrod(::Type{T}, ::Integer) where T<:AbstractFloat
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:316
  kronrod(::AbstractMatrix{<:Real}, ::Integer, ::Real, ::Pair{<:Tuple{Real, Real}, <:Tuple{Real, Real}})
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:390

  [1] macro expansion
    @ ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:564 [inlined]
  [2] _cachedrule
    @ ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:564 [inlined]
  [3] cachedrule
    @ ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:569 [inlined]
  [4] do_quadgk(f::Integrals.var"#53#59"{…}, s::Tuple{…}, n::Int64, atol::Float64, rtol::Float64, maxevals::Int64, nrm::typeof(LinearAlgebra.norm), segbuf::Vector{…})
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:7
  [5] #50
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:253 [inlined]
  [6] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::Integrals.var"#53#59"{…}, s::Tuple{…})
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:145
  [7] #quadgk#49
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:252 [inlined]
  [8] quadgk
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:250 [inlined]
  [9] #__solvebp_call#47
    @ ~/.julia/packages/Integrals/tvunm/src/Integrals.jl:142 [inlined]
 [10] __solvebp_call
    @ ~/.julia/packages/Integrals/tvunm/src/Integrals.jl:88 [inlined]
 [11] #__solvebp#1
    @ ~/.julia/packages/Integrals/tvunm/ext/IntegralsForwardDiffExt.jl:27 [inlined]
 [12] __solvebp
    @ ~/.julia/packages/Integrals/tvunm/ext/IntegralsForwardDiffExt.jl:7 [inlined]
 [13] solve!(cache::Integrals.IntegralCache{…})
    @ Integrals ~/.julia/packages/Integrals/tvunm/src/common.jl:105
 [14] solve(prob::IntegralProblem{…}, alg::QuadGKJL{…}; kwargs::@Kwargs{…})
    @ Integrals ~/.julia/packages/Integrals/tvunm/src/common.jl:101
 [15] f(Θ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 3}})
    @ Main ~/.julia/dev/sandbox/DDM/hcubature_turing/hcubature_turing.jl:102
 [16] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [17] vector_mode_gradient(f::typeof(f), x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:89
 [18] gradient(f::Function, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…}, ::Val{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:19
 [19] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{…}, Float64, 3, Vector{…}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [20] top-level scope
    @ ~/.julia/dev/sandbox/DDM/hcubature_turing/hcubature_turing.jl:108
Some type information was truncated. Use `show(err)` to see complete types.
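In the QuadGK version shown in the stack trace, the closest candidate is kronrod(::Type{T}, ::Integer) where T<:AbstractFloat. ForwardDiff's Dual type is a subtype of Real but not of AbstractFloat, so the call cannot dispatch once dual-valued bounds reach QuadGK. A quick check, assuming only ForwardDiff.jl is installed (Nothing is used as a stand-in tag type):

```julia
using ForwardDiff

# A dual-number type like the ones ForwardDiff builds inside gradient();
# the real tag would be ForwardDiff.Tag{typeof(f), Float64}, Nothing stands in here.
D = ForwardDiff.Dual{Nothing, Float64, 3}

println(D <: Real)           # true
println(D <: AbstractFloat)  # false, so the T<:AbstractFloat kronrod method does not apply
```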

Environment

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, haswell)
Threads: 4 default, 0 interactive, 2 GC (on 8 virtual cores)
  JULIA_CMDSTAN_HOME = /home/dfish/cmdstan

Additional context

This error was discovered by @gdalle. Please ping him for technical details, as AD is not my expertise.

He notes that the issue might be a missing rule for integral bounds:

lxvm commented 3 months ago

Right, at the moment we are missing a rule for differentiation with respect to the bounds. We want to implement the Leibniz integral rule, but there are several ways to do this, e.g. whether you discretize then differentiate or vice versa. This choice determines whether you end up with an integral over the same domain or with integrals over the boundaries; I think the former is preferable, but there are also arguments for the latter. Adding a rule will be straightforward, I hope, but ultimately users will want control over the differentiation algorithm, so I also marked this as a feature request.
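For the one-dimensional case in this report, the Leibniz rule for F(p, a, b) = ∫ₐᵇ g(x, p) dx gives ∂F/∂p = ∫ₐᵇ ∂g/∂p dx, ∂F/∂a = -g(a, p) and ∂F/∂b = g(b, p), where the boundary terms are point evaluations of the integrand. The sketch below applies this rule by hand as a possible user-side workaround; it is not the Integrals.jl implementation. It assumes QuadGK.jl and ForwardDiff.jl are installed, writes the Normal density out by hand instead of calling Distributions.jl, and uses a hypothetical helper named leibniz_gradient. The lower bound is 0.5 rather than 0 so that the boundary term never evaluates the density at a standard deviation of zero.

```julia
using QuadGK       # quadgk(f, a, b; rtol) returns (integral, error estimate)
using ForwardDiff  # ForwardDiff.derivative for the ∂g/∂p term

# Normal(p, x) density evaluated at 1, written out by hand
# (same integrand as in the report, without Distributions.jl).
g(x, p) = exp(-(1 - p)^2 / (2 * x^2)) / (x * sqrt(2π))

# Hypothetical helper: gradient of F(p, a, b) = ∫_a^b g(x, p) dx via the
# Leibniz integral rule. p, a and b are plain Float64 values, so no dual
# numbers are passed to quadgk as bounds.
function leibniz_gradient(g, p, a, b; rtol = 1e-8)
    # Interior term: differentiate the integrand w.r.t. p, then integrate.
    dFdp, _ = quadgk(x -> ForwardDiff.derivative(q -> g(x, q), p), a, b; rtol = rtol)
    # Boundary terms: ∂F/∂a = -g(a, p) and ∂F/∂b = g(b, p).
    return [dFdp, -g(a, p), g(b, p)]
end

# Gradient with respect to (p, a, b) at p = 0, a = 0.5, b = 2
grad = leibniz_gradient(g, 0.0, 0.5, 2.0)
```

Comparing the result against central finite differences of F is a cheap sanity check for a hand-written rule like this one.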

We already have a "broken" test case for this specific issue: