JuliaMath / QuadGK.jl

adaptive 1d numerical Gauss–Kronrod integration in Julia
MIT License

Autodiff of `quadgk` #73

Open Vilin97 opened 1 year ago

Vilin97 commented 1 year ago

Thank you for a great package! I would like to use Zygote on `quadgk`, but below is a self-contained example where it does not work because `quadgk` calls `setindex!`. Is there anything I can do to work around this? I coded up a small example of differentiating `quadgk` of a constant function and it did work, so it seems there is hope for Zygote and QuadGK to play together nicely.

```julia
using Zygote, QuadGK

function F(c, x)
    if x < c
        return 0.
    elseif x < 1.
        return exp(c)*(x-c)^2/2
    elseif x < 1. + c
        return exp(c)*(1-c)*(2x - c - 1.)/2.
    elseif x < 2.
        return exp(c)*(4x - x^2 - 2c - 2)/2.
    else
        return exp(c)*(1-c)
    end
end

"pdf of X"
function f(c, x)
    if x < c || x > 2.
        return 0.
    elseif x < 1.
        return exp(c)*(x-c)
    elseif x < 1. + c
        return exp(c)*(1-c)
    else
        return exp(c)*(2-x)
    end
end

function D(c)
    # definition omitted in the post
end

function W(c1, c2)
    val, acc = quadgk(t -> F(c1, t)*f(c2, t), c1, 2)
    return val
end

function prob_of_win(c1, c2)
    ((1-D(c1))*D(c2) + W(c2, c1))/(D(c1) + D(c2) - D(c1)*D(c2))
end

gradient(prob_of_win, 0.25, 0.25) # ERROR: Mutating arrays is not supported -- called setindex!(Vector{QuadGK.Segment{Float64, Float64, Float64}}, ...)
```

stevengj commented 1 year ago

Basically this needs to be done by writing ChainRules that tell Zygote how to differentiate integrals without differentiating through the (mutating) code.
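For the `W` in the question, the derivatives such a rule has to supply follow from the Leibniz integral rule. This is a sketch assuming 0 ≤ c < 1, so that the breakpoints are ordered c < 1 < 1 + c < 2. Differentiating under the integral sign is justified here because `F` and `f` are continuous in t (adjacent branches agree at every breakpoint), so the breakpoints that move with c contribute no jump terms:

```math
\begin{aligned}
W(c_1, c_2) &= \int_{c_1}^{2} F(c_1, t)\, f(c_2, t)\, dt \\
\frac{\partial W}{\partial c_1} &= \int_{c_1}^{2} \frac{\partial F}{\partial c}(c_1, t)\, f(c_2, t)\, dt \;-\; F(c_1, c_1)\, f(c_2, c_1) \\
\frac{\partial W}{\partial c_2} &= \int_{c_1}^{2} F(c_1, t)\, \frac{\partial f}{\partial c}(c_2, t)\, dt
\end{aligned}
```

The boundary term vanishes because F(c, c) = 0. Both remaining derivatives are ordinary integrals that `quadgk` can evaluate directly, without Zygote ever tracing through its mutating code.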

Fortunately, this has already been done: https://github.com/SciML/Integrals.jl provides a wrapper around QuadGK that defines the appropriate ChainRules.

In the longer run it might be worth adding chain rules to QuadGK directly.
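A minimal sketch of what such a rule could look like. It assumes ChainRulesCore.jl, ForwardDiff.jl and Zygote.jl are installed. `integrate` is a hypothetical helper, not part of QuadGK's API. It computes ∫ₐᵇ g(t, p) dt, and its `rrule` applies the Leibniz rule: ∂/∂p is taken under the integral with ForwardDiff and integrated by a second `quadgk` call, while the endpoint tangents are −g(a, p) and g(b, p). Zygote therefore never sees `quadgk`'s internals. The sketch also assumes `g` carries no differentiable captured state, so its tangent is `NoTangent()`. `F` and `f` are copied from the question with the missing `else`/`end` restored.

```julia
using QuadGK, ForwardDiff, ChainRulesCore, Zygote

function F(c, x)
    if x < c
        return 0.0
    elseif x < 1.0
        return exp(c) * (x - c)^2 / 2
    elseif x < 1.0 + c
        return exp(c) * (1 - c) * (2x - c - 1.0) / 2.0
    elseif x < 2.0
        return exp(c) * (4x - x^2 - 2c - 2) / 2.0
    else
        return exp(c) * (1 - c)
    end
end

function f(c, x)
    if x < c || x > 2.0
        return 0.0
    elseif x < 1.0
        return exp(c) * (x - c)
    elseif x < 1.0 + c
        return exp(c) * (1 - c)
    else
        return exp(c) * (2 - x)
    end
end

# Hypothetical helper (not QuadGK API): ∫_a^b g(t, p) dt with a tight tolerance.
integrate(g, p, a, b) = first(quadgk(t -> g(t, p), a, b; rtol = 1e-10))

# Reverse rule so Zygote uses these tangents instead of differentiating quadgk.
# Leibniz rule: d/dp = ∫ ∂g/∂p dt, d/da = -g(a, p), d/db = g(b, p).
# Assumes g has no differentiable captured state (tangent for g is NoTangent()).
function ChainRulesCore.rrule(::typeof(integrate), g, p, a, b)
    y = integrate(g, p, a, b)
    function integrate_pullback(ȳ)
        Δ = unthunk(ȳ)
        # ∂g/∂p via forward mode at each t, integrated as a vector-valued integrand
        ∂p = first(quadgk(t -> ForwardDiff.gradient(q -> g(t, q), p), a, b; rtol = 1e-10))
        return NoTangent(), NoTangent(), Δ * ∂p, -Δ * g(a, p), Δ * g(b, p)
    end
    return y, integrate_pullback
end

# W from the question: c1 and c2 enter through p, and c1 is also the lower limit.
W(c1, c2) = integrate((t, p) -> F(p[1], t) * f(p[2], t), [c1, c2], c1, 2.0)

dW = Zygote.gradient(W, 0.25, 0.4)   # (∂W/∂c1, ∂W/∂c2)
```

Because `c1` enters both through `p` and through the lower limit `a`, Zygote adds the two contributions (the endpoint one is zero here, since F(c, c) = 0). The tight `rtol = 1e-10` is an arbitrary choice; it keeps quadrature noise well below the step size when checking the result against central finite differences.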

lxvm commented 10 months ago

Autodiff in Integrals.jl was recently fixed in this PR https://github.com/SciML/Integrals.jl/pull/175 and it should be available in its next release.