EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Cannot deduce type for QuadGK #1599

Closed mhauru closed 1 month ago

mhauru commented 2 months ago

MWE:

module MWE

import QuadGK.quadgk
using Enzyme

function g(x)
    function g_one(y)
        return exp(-0.5 * y * (2 * x + y) * (2 * x + y))
    end
    return quadgk(g_one, 0.0, Inf)[1]
end

Enzyme.autodiff(Enzyme.Reverse, g, Duplicated(0.5, 0.5))

end

Output: https://gist.githubusercontent.com/mhauru/25e6fa41671b94cb392b9df01dd8f821/raw/3d175b8d8a08216df8b2d6696b0635b430f3ecc0/QuadGK_cannot_deduce_type

wsmoses commented 1 month ago

Offhand I don't know what the correct reverse rule is, but here's some infrastructure for one to be added:

using EnzymeCore
import EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f::Const, segs::Annotation{T}...; kws...) where {RT, T}
    prims = map(x->x.val, segs)
    res = ofunc.val(f.val, prims...; kws...)

    retres = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    dres = if EnzymeRules.width(config) == 1
        zero.(res)
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            zero.(res)
        end
    end

    cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
        dres
    else
        nothing
    end

    return EnzymeCore.EnzymeRules.AugmentedReturn{
        EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
        EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
        typeof(cache)
    }(retres, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Const, segs::Annotation{T}...; kws...) where {T}
    TODO
end

@stevengj since it looks like you're a code owner, do you have any insights and/or would want to work through the rule together?

stevengj commented 1 month ago

Is the differentiate-then-discretize approximation acceptable (i.e. neglecting the error in the quadrature rule)? Then the Jacobian with respect to a parameter of an integrand is just the integral of the Jacobian of the integrand. Or, in the case of a reverse rule, the integral of vJp of the integrand. I think this is the approach used by Integrals.jl. My student @lxvm worked on this, so cc'ing him.

If you want the exact "discretize-then-differentiate" Jacobian (including quadrature/discretization error), i.e. the exact derivative of the approximate integral up to roundoff error, then the simple approach won't work because quadgk is adaptive — it may use different quadrature points when integrating the vJp than it did for the original integrand.

One slick approach to "discretize-then-differentiate" would be to call quadgk with both the original integrand and the vJp integrand, combined into a vector of integrands (ideally something like an SVector so that you don't introduce new heap allocations — is there an SVector-like type with heterogeneous elements, for Cartesian-product vector spaces?). To ensure that it uses the same quadrature points as it does with the original integrand, you can pass a custom error norm function to quadgk that only takes the norm of the first element (i.e. so that the vJp part is ignored for convergence/refinement purposes). This would return both the integral and the vJp; hopefully the former could be re-used so that it isn't computed twice.
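A minimal sketch of this combined-integrand idea (my illustration, not code from the thread; a plain `Vector` stands in for the SVector-like type, and the derivative integrand is written by hand):

```julia
using QuadGK

p = 2.0
f(x)    = exp(-p * x^2)          # original integrand
dfdp(x) = -x^2 * exp(-p * x^2)   # its derivative w.r.t. p, hand-written here

# Integrate both components together, but judge convergence/refinement only
# on the first (primal) component via a custom error norm.
I, E = quadgk(x -> [f(x), dfdp(x)], 0.0, 1.0; norm = v -> abs(v[1]))

# I[1] is the integral and I[2] its p-derivative, computed on shared points.
```

Both components are thus evaluated on exactly the same adaptive mesh, chosen by the primal integrand alone.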

wsmoses commented 1 month ago

I think if we write a rule (which makes sense here numerically for the reasons of different quadrature points specified above anyways), we might as well go for the discretize-then-differentiate solution.

One different notion here: in Enzyme reverse mode, the original result is required to be computed in the forward pass [and optionally never computed, if the compiler tells Enzyme it isn't needed and sets needs_primal to false], so I'm not sure how to fuse everything into one quadgk call. Moreover, wouldn't it be more stable to have two distinct quadgk calls that pick different points? Or am I misunderstanding you and/or the magnitude of the performance implications?

stevengj commented 1 month ago

(Similarly, for the derivative with respect to an endpoint of the integration domain, there is a simple differentiate-then-discretize rule using the fundamental theorem of calculus. The discretize-then-differentiate rule is more complicated, but can also be implemented by augmenting the integrand.)
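As a concrete illustration of that endpoint rule (my sketch, not code from the thread): the fundamental theorem of calculus gives $\partial I/\partial b = f(b)$ and $\partial I/\partial a = -f(a)$, which a finite difference of the adaptive integral reproduces up to quadrature and truncation error:

```julia
using QuadGK

f(x) = sin(x)^2
a, b = 0.0, 2.0

dI_db = f(b)     # differentiate-then-discretize endpoint derivative
dI_da = -f(a)

# finite-difference check on tightly converged integrals (illustration only)
h = 1e-5
fd_b = (quadgk(f, a, b + h; rtol=1e-12)[1] - quadgk(f, a, b - h; rtol=1e-12)[1]) / (2h)
```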

Another question is whether you want something specific to QuadGK, or a more generic method for Integrals.jl. Unfortunately, supporting multiple backends in Integrals.jl may be harder, since not all solvers support specifying a custom error norm.

stevengj commented 1 month ago

Moreover wouldn't it be more stable to have two distinct quadgk calls to pick the different points

It can't pick the points until it is actually computing the integral, since it adaptively looks at error estimations from the estimated integral so far.

wsmoses commented 1 month ago

Unless the quadgk rule would be obnoxious, I think it makes sense to add a rule for quadgk, and possibly also for Integrals.jl later down the line. The reason is that we've already seen various packages that use quadgk directly (not via Integrals.jl) hitting issues as a result.

We have something similar for the various Julia solvers: Enzyme supports them directly internally, and there is also a rule within the SciML solver packages.

stevengj commented 1 month ago

If you have to have two separate calls for the original result and the vJp, then you either have to pay the price for estimating the integrand twice, or add a new API to QuadGK that allows it to cache all of the integrand points and weights, or accept the approximation of the differentiate-then-discretize approach.

Actually, it's not crazy to cache all of the integrand points and weights, since effectively QuadGK already saves this information (it builds up a heap of subintervals and knows the quadrature rule for each subinterval). In fact, QuadGK's segbuf API almost gives the caller access to this, so with only slight modifications it might be possible to re-use this for a reverse rule.

wsmoses commented 1 month ago

Well we can indeed save arbitrary data so this should be doable if quadgk had the API

stevengj commented 1 month ago

I implemented a new API for this in https://github.com/JuliaMath/QuadGK.jl/pull/108

Now, if you have a function call (I, E) = quadgk(...), you can replace it with:

I, E, segbuf = quadgk_segbuf(...)

(There is generally no extra cost to this, since QuadGK computes the segbuf internally anyway.) Then, on a subsequent call to quadgk, even with a different integrand vJp, you can pass:

quadgk(vJp, ...; ..., eval_segbuf=segbuf, maxevals=0)

and it will evaluate the new integrand using exactly the same quadrature rule (the same subintervals).

This should make it fairly easy and efficient to implement the exact derivative (discretize-then-differentiate) of the integral estimate I with respect to parameters of the integrand, or with respect to the endpoints.
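Assuming the API from that PR (quadgk_segbuf plus the eval_segbuf/maxevals keywords, available in recent QuadGK releases), the two passes of a reverse rule would look roughly like this sketch (the second integrand here is a hand-written stand-in for a vJp):

```julia
using QuadGK

f(x) = exp(-x^2)

# primal pass: compute the integral and capture the adaptive subintervals
I, E, segbuf = quadgk_segbuf(f, 0.0, 1.0; rtol=1e-10)

# pullback pass: evaluate a different integrand on exactly the same
# quadrature rule, with no further adaptive refinement (maxevals=0)
g(x) = -x^2 * exp(-x^2)   # stand-in for the vJp of the integrand
dI, _ = quadgk(g, 0.0, 1.0; eval_segbuf=segbuf, maxevals=0)
```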

(Is there an optional dependency that we can add to QuadGK.jl, analogous to ChainRulesCore.jl, in order to add the AD rules directly to the package?)

wsmoses commented 1 month ago

Yeah we can just add a dependency to EnzymeCore


lxvm commented 1 month ago

I think @stevengj has the right idea with adding a segbuf API to QuadGK.jl, because with a fixed quadrature rule (in this case a composite and adaptive quadrature rule) the discretize-then-differentiate rules are very simple. Although calculus offers very convenient differentiate-then-discretize rules, it is not obvious how to automatically select absolute error tolerances for the adaptive integration of the vJp or Jvp, which as functions may behave differently from the original integrand, and if you think of them as dimensionful quantities they also have different units (this point is related to Steven's earlier point about choosing a norm). Fortunately, relative error tolerances don't have this issue, but often an absolute tolerance is still important to get an efficient quadrature rule.

That being said, I think Integrals.jl will first go down the differentiate-then-discretize route because it is compatible with solvers that accept vector-valued integrands.

(Already this new API for QuadGK.jl is innovative, since most adaptive integrators don't return the quadrature rule that they construct internally. I recently added a similar feature to HChebInterp.jl so that a user can construct an adaptive interpolant starting from a grid used by another available interpolant.)

So is the following the current proposed scheme?:

stevengj commented 1 month ago

Call quadgk to integrate the original function and gradients with the eval_segbuf=segbuf and maxevals=0 keywords (AD used here)

My understanding is that we won't be using the original function on the second call, since we don't want to pay the price of integrating it again.

The proposal instead is to implement a custom reverse rule for quadgk (and friends) that transforms it to a call of quadgk_segbuf, saves the segbuf, and then in the pullback calls quadgk(..., eval_segbuf=segbuf, maxevals=0) with a new integrand (based on the vJp of the original integrand) to compute the vJp of the integral.

stevengj commented 1 month ago

In particular, suppose that the original integral is $$I(p) = \int_{a(p)}^{b(p)} f(x, p) dx = (b(p)-a(p))\int_0^1 f(x(t,p), p) dt$$ where $p$ is some vector of parameters affecting the integrand and/or endpoints, and $x(t,p) = a(p) + (b(p) - a(p)) t$ is just an affine change of variables. Here, I've pulled the $[a,b]$ domain out of the integral limits in order to avoid using the fundamental theorem of calculus, which is only approximate for quadrature rules.

Instead, if we evaluate $I(p)$ and then save the quadrature rule via segbuf, the formulas below become exact if we re-use the quadrature rule with eval_segbuf=segbuf, maxevals=0.

For a reverse rule, we want the vJp $v^T \frac{\partial I}{\partial p}$ for some vector $v$, with $\frac{\partial I}{\partial p}$ denoting the Jacobian. This is given by a new integral:

$$\begin{aligned} v^T \frac{\partial I}{\partial p} = {}& v^T\underbrace{\frac{\partial \ln(b-a)}{\partial p}}_{c} I \\ & {}+ \int_a^b \left[ v^T \frac{\partial f}{\partial p} + v^T \left(\frac{\partial a}{\partial p} + c(x-a)\right) \frac{\partial f}{\partial x}\right] dx \end{aligned}$$

in which, assuming I haven't made any algebra errors, you will get the exact result if you plug the new integrand into the old quadrature rule.
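As a quick sanity check of the algebra (my addition, not from the thread): specialize to $a(p) = 0$, $b(p) = p$, $f$ independent of $p$, and scalar $v = 1$, so that $c = 1/p$ and $\partial a/\partial p = 0$. Integrating by parts recovers the fundamental theorem of calculus:

$$\frac{\partial I}{\partial p} = \frac{I}{p} + \int_0^p \frac{x}{p}\,\frac{\partial f}{\partial x}\, dx = \frac{I}{p} + \frac{1}{p}\left(p f(p) - \int_0^p f\, dx\right) = f(p)$$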

(quadgk also allows the caller to subdivide the integration domain into multiple intervals. This is more complicated to handle if you want the derivative with respect to the interval boundaries, but we could punt on those derivatives to start with.)

wsmoses commented 1 month ago

Sorry, thinking about this more, I think we should probably elect differentiate-then-discretize rather than the other way around. This is mostly a matter of convention: are people calling the package expecting it to sum up over assumed quadrature points [which would imply discretize-then-differentiate is the right solution], or do they assume that it will generically integrate, with some possible integration error [in which case differentiate-then-discretize is correct]?

Of course both of these have their relevant use cases, but thinking on it more, differentiate-first feels like the more reasonable default for users [who don't, say, intend to AD the relative error of the discretization itself].

Maybe we come up with an option for being able to specify the behavior and pick as a result

lxvm commented 1 month ago

My understanding is that we won't be using the original function on the second call, since we don't want to pay the price of integrating it again.

Although this sounds like an optimization, it could be a worthwhile one, especially in reverse mode, where you may have to cache temporary values of the integrand at all points of evaluation. If I had to implement this in Zygote, I would write something like this:

using QuadGK, Zygote, ChainRulesCore, LinearAlgebra

function ChainRulesCore.rrule(::typeof(quadgk), f, a, b; norm=norm, order=7, kws...)
    # adaptive pass on the rescaled domain [-1, 1], saving the subintervals
    I, E, segbuf = quadgk_segbuf(x -> f((b-a)*x/2+(a+b)/2), -1, 1; norm, order, kws...)
    x, w, wg = QuadGK.cachedrule(typeof((a+b)/2), order)
    # re-evaluate the frozen rule inside the pullback so that AD sees a fixed sum
    I_ab, back = Zygote.pullback(f, a, b) do f, a, b
        s, c = (b-a)/2, (a+b)/2
        sum(QuadGK.evalrule(f, seg.a*s+c, seg.b*s+c, x, w, wg, norm).I for seg in segbuf)
    end
    return (I_ab, E*(b-a)/2), (dI, _) -> (NoTangent(), back(dI)...)
end

This is given by a new integral:

Hopefully the rule I wrote above lets the AD framework perform the calculations you detailed

we could punt on those derivatives to start with

In the example above I already have. Also, if someone specifies breakpoints, isn't it usually because the integrand is non-differentiable at that point?

lxvm commented 1 month ago

Maybe we come up with an option for being able to specify the behavior and pick as a result

I think that would be good to build in from the beginning, since both behaviors may be desirable.

stevengj commented 1 month ago

If the quadrature is sufficiently converged that you can use differentiate-then-discretize, then you should also be able to use discretize-then-differentiate.

The discretize-then-differentiate approach requires more support from QuadGK, but it should be more efficient. Not only do you not have to worry about tolerances, as @lxvm points out, but it also requires fewer function evaluations (usually about half as many) since it skips the adaptive subdivision steps.

stevengj commented 1 month ago

Hopefully the rule I wrote above lets the AD framework perform the calculations you detailed

This rule won't work for infinite limits, in-place integrands, or batched integrands, for example.
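For instance (my illustration), quadgk handles semi-infinite limits through an internal change of variables, so a rule built on an affine rescaling of finite endpoints has no analogue for a call like this:

```julia
using QuadGK

# quadgk maps [0, Inf) onto a finite interval internally, so the (b-a)/2
# rescaling used for finite endpoints does not apply to this integral.
I, E = quadgk(x -> exp(-x), 0.0, Inf)
```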

stevengj commented 1 month ago

Of course both of these have their relevant use cases, but thinking on it more differentiate first feels like the more reasonable default for users [who don't say intend to AD the relative error of the discretization itself].

Even if the user imagines that they are solving the problem exactly and is not thinking about the derivative of the error terms, it is often better to AD the discretization error too if it is practical to do so, e.g. if you use this inside an optimization algorithm (which will expect the derivative to predict the first-order change in the computed function).

The caveat is that, for adaptive quadrature, if they are updating the integrand parameters then the adaptive quadrature mesh will often change. This effectively introduces small discontinuities into the function that will screw up optimization if they become too large, whether you use differentiate-then-discretize or vice versa. So for adaptive algorithms, in practice, the user will have to ensure that it is sufficiently converged. (But at this point you can use either AD scheme, as I mentioned above, and discretize-then-differentiate can actually be more efficient in principle.)

stevengj commented 1 month ago

It might be worth sticking with the differentiate-then-discretize approach for differentiating with respect to the integration endpoints, however, as in that case the analytical rule is vastly cheaper and easier to compute, and it also eliminates a lot of the complications that arise with the discretize-then-differentiate approach for multi-interval domains.

So, maybe a hybrid scheme: differentiate-then-discretize for the endpoints, but discretize-then-differentiate for parameters of the integrand. This way we get the best of both worlds. (And the user shouldn't care if their integral is sufficiently converged.)

wsmoses commented 1 month ago

using EnzymeCore
import EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f::Const, segs::Annotation{T}...; kws...) where {RT, T}
    prims = map(x->x.val, segs)
    res = ofunc.val(f.val, prims...; kws...)

    retres = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    dres = if EnzymeRules.width(config) == 1
        zero.(res)
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            zero.(res)
        end
    end

    cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
        dres
    else
        nothing
    end

    return EnzymeCore.EnzymeRules.AugmentedReturn{
        EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
        EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
        typeof(cache)
    }(retres, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Const, segs::Annotation{T}...; kws...) where {T}
    res = ofunc.val(EnzymeCore.autodiff(Reverse, f.val, segs...); kws...)
    ntuple(Val(length(segs))) do i
        Base.@_inline_meta
        if segs[i] isa Const
            nothing
        elseif EnzymeCore.EnzymeRules.width(config) == 1
            dres * res[i]
        else
            ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do j
                Base.@_inline_meta
                dres * res[i][j]
            end
        end
    end
end

stevengj commented 1 month ago

Draft 2, missing ClosureVector implementation:

using Revise, Enzyme, QuadGK, EnzymeCore, LinearAlgebra
import EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T<:Real}
    prims = map(x->x.val, segs)

    I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...)
    res = (I, E)
    retres = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    dres = if !EnzymeCore.EnzymeRules.needs_shadow(config)
        nothing
    elseif EnzymeRules.width(config) == 1
        zero.(res)
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            zero.(res)
        end
    end

    cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
        dres
    else
        nothing
    end
    cache2 = segbuf, cache

    return EnzymeCore.EnzymeRules.AugmentedReturn{
        EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
        EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
        typeof(cache2)
    }(retres, dres, cache2)
end

function call(f, x)
    f(x)
end

struct ClosureVector{F}
    f::F
end

function Base.:+(a::ClosureVector, b::ClosureVector)
    return a
    # throw(AssertionError("todo +"))
end

function Base.:-(a::ClosureVector, b::ClosureVector)
    return a+(-1*b)
end

function Base.:*(a::Number, b::ClosureVector)
    return b
    # throw(AssertionError("todo +"))
end

function Base.:*(a::ClosureVector, b::Number)
    return b*a
end

function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f, segs::Annotation{T}...; kws...) where {T<:Real}
    # res = ofunc.val(EnzymeCore.autodiff(Reverse, f.val, segs...); kws...)

    df = if f isa Const
        nothing
    else
        segbuf = cache[1]
        fwd, rev = EnzymeCore.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
        _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
            tape, prim, shad = fwd(Const(call), f, Const(x))
            drev = rev(Const(call), f, Const(x), dres.val[1], tape)
            return ClosureVector(drev[1][1])
        end
        _df.f
    end
    dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
    dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
    return (df, # f
            dsegs1,
            ntuple(i -> nothing, Val(length(segs)-2))...,
            dsegsn)
end