f0uriest / quadax

Numerical quadrature with JAX
MIT License

weighted quadrature #5

Open twhentschel opened 5 months ago

twhentschel commented 5 months ago

I use weighted quadrature in my research, and I was thinking of using quadax to replace scipy.integrate.quad in my work. I'd be interested in contributing a feature that handles weighted integration. Would you be open to a PR like that?

f0uriest commented 5 months ago

Yes definitely! I haven't had the time to dig into how it's done exactly but if you want to take a stab at it that would certainly be welcome.

twhentschel commented 4 months ago

Okay, I've been looking at weighted quadrature schemes and how they are implemented in QUADPACK. In QUADPACK, each type of weight function appears to have its own adaptive bisection scheme: depending on whether a newly bisected subinterval contains the singularity, a different fixed-order rule is called, and in some cases Clenshaw-Curtis and Gauss-Kronrod are used on different subregions. This contrasts with quadax, where a single adaptive_quadrature function handles the adaptive part of every quadrature method, given one rule. I've been trying to work out how to write a weighted quadrature method without writing a specialized adaptive_* routine as QUADPACK does, so it has taken a while to get this up and running. The good news is that I think I understand how some of QUADPACK's weighting schemes work; it's now a matter of translating them to work with quadax.
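To make the idea concrete, here is a minimal NumPy sketch (not quadax code) of an h-adaptive loop that stays generic: it repeatedly bisects the subinterval with the largest error estimate and asks a caller-supplied `select_rule(lo, hi)` which fixed rule to apply to each new piece. The names `gauss_rule`, `adaptive` and `select_rule` are hypothetical, and the Gauss-Legendre rule with a 10-vs-20-point error estimate stands in for the Gauss-Kronrod and Clenshaw-Curtis rules QUADPACK actually uses.

```python
import heapq

import numpy as np


def gauss_rule(f, a, b, n=10):
    """2n-point Gauss-Legendre estimate on [a, b], with |G_2n - G_n| as the error."""
    def gl(m):
        t, w = np.polynomial.legendre.leggauss(m)
        x = 0.5 * (b - a) * t + 0.5 * (a + b)
        return 0.5 * (b - a) * np.dot(w, f(x))

    lo, hi = gl(n), gl(2 * n)
    return hi, abs(hi - lo)


def adaptive(f, a, b, select_rule, tol=1e-10, max_ninter=50):
    """Bisect the worst subinterval; select_rule(lo, hi) picks the rule for each piece."""
    def evaluate(lo, hi):
        est, err = select_rule(lo, hi)(f, lo, hi)
        return (-err, lo, hi, est)  # heapq is a min-heap, so store -err

    heap = [evaluate(a, b)]
    while -sum(h[0] for h in heap) > tol and len(heap) < max_ninter:
        _, lo, hi, _ = heapq.heappop(heap)
        mid = 0.5 * (lo + hi)
        heapq.heappush(heap, evaluate(lo, mid))
        heapq.heappush(heap, evaluate(mid, hi))
    return sum(h[3] for h in heap), -sum(h[0] for h in heap)
```

A weighted method would then only have to supply `select_rule`, for example returning a modified rule when `lo` or `hi` coincides with a singular point and a plain rule otherwise.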

twhentschel commented 4 months ago

Okay, here's my tentative plan for the specific algebraic-logarithmic weight function $w(x) = (x - c_1)^\alpha (c_2 - x)^\beta \log(x - c_1) \log(c_2 - x)$, where $c_1 < c_2$ are the singular points of the weight:
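For reference, QUADPACK's QAWS routine (exposed in scipy.integrate.quad via `weight='alg'`, `'alg-loga'`, `'alg-logb'` or `'alg-log'` with `wvar=(alpha, beta)`) covers this whole family, with the singular points at the integration limits and $\alpha, \beta > -1$. A minimal NumPy sketch of the weight, with a hypothetical `kind` argument mirroring those four options:

```python
import numpy as np


def alglog_weight(x, c1, c2, alpha, beta, kind="alg-log"):
    """Evaluate the algebraic-logarithmic weight for c1 < x < c2.

    kind mirrors scipy's names: 'alg' (no log factor), 'alg-loga' (log(x - c1)),
    'alg-logb' (log(c2 - x)) or 'alg-log' (both log factors).
    """
    left, right = x - c1, c2 - x
    w = left**alpha * right**beta
    if kind in ("alg-loga", "alg-log"):
        w = w * np.log(left)
    if kind in ("alg-logb", "alg-log"):
        w = w * np.log(right)
    return w
```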

Say we want to integrate $f(x) w(x)$ from $a$ to $b$, where $f(x)$ is some easy-to-integrate function and $w(x)$ is the weight function above. If $c_1 = a$ or $c_2 = b$ (but not both), we can use a modified Clenshaw-Curtis (CC) fixed-order rule, in which the weight function modifies the weights of the basic CC approach (specifically, through modified Chebyshev moments). For now, let's assume $c_1 = a$.
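A minimal NumPy sketch of such a modified CC rule, assuming the simplest member of the family: only the algebraic factor $(x - a)^\alpha$ (no logarithms, $\beta = 0$). It interpolates $f$ at the CC nodes, takes the Chebyshev coefficients of that interpolant, and combines them with the moments $m_k = \int_{-1}^{1} (1 + t)^\alpha T_k(t)\,dt$, computed with the recurrence used in QUADPACK's QMOMO routine; the weight itself is therefore integrated exactly. The log-weighted cases need further moment sequences that are not shown, and `alg_moments`/`modified_cc` are hypothetical names.

```python
import numpy as np


def alg_moments(alpha, nmom):
    """m[k] = integral over [-1, 1] of (1 + t)**alpha * T_k(t), k = 0..nmom-1."""
    m = np.zeros(nmom)
    ralf = 2.0 ** (alpha + 1)
    m[0] = ralf / (alpha + 1)
    if nmom > 1:
        m[1] = m[0] * alpha / (alpha + 2)
    for k in range(2, nmom):
        m[k] = -(ralf + k * (k - alpha - 2) * m[k - 1]) / ((k - 1) * (k + alpha + 1))
    return m


def modified_cc(f, a, b, alpha, n=16):
    """Approximate the integral of (x - a)**alpha * f(x) over [a, b] (n even)."""
    j = np.arange(n + 1)
    t = np.cos(j * np.pi / n)                 # CC nodes on [-1, 1]
    x = a + 0.5 * (b - a) * (t + 1.0)         # mapped to [a, b]
    half = np.ones(n + 1)
    half[0] = half[-1] = 0.5                  # first/last terms are halved
    k = j[:, None]
    # Chebyshev coefficients of the degree-n interpolant of f at the CC nodes
    c = (2.0 / n) * (np.cos(np.pi * k * j / n) @ (half * f(x)))
    m = alg_moments(alpha, n + 1)
    # (x - a)**alpha dx = ((b - a) / 2)**(alpha + 1) * (1 + t)**alpha dt
    return (0.5 * (b - a)) ** (alpha + 1) * np.sum(half * c * m)
```

Because the interpolant is exact for polynomials of degree at most `n`, `modified_cc(lambda x: x**2, 0.0, 1.0, -0.5, n=8)` reproduces $\int_0^1 x^{3/2}\,dx = 2/5$ up to rounding.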

However, in the h-adaptive scheme, when we bisect $[a, b]$, the second subinterval $[(a + b)/2, b]$ no longer contains the singular point $a$, so the modified CC rule can't be used there. Instead, we fall back to the basic CC method on this interval, which is well-behaved because it has no singular point.
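For the fallback, a plain CC rule is enough: on $[(a + b)/2, b]$ the factor $(x - a)^\alpha$ is smooth, so it can simply be multiplied into the integrand. A minimal NumPy sketch (the name `basic_cc` is hypothetical; this is the textbook Chebyshev-coefficient form of CC, not quadax's `fixed_quadcc`):

```python
import numpy as np


def basic_cc(g, a, b, n=16):
    """Plain (n+1)-point Clenshaw-Curtis estimate of the integral of g over [a, b], n even."""
    j = np.arange(n + 1)
    t = np.cos(j * np.pi / n)                 # CC nodes on [-1, 1]
    x = 0.5 * (b - a) * t + 0.5 * (a + b)     # mapped to [a, b]
    half = np.ones(n + 1)
    half[0] = half[-1] = 0.5
    k = j[:, None]
    c = (2.0 / n) * (np.cos(np.pi * k * j / n) @ (half * g(x)))  # Chebyshev coeffs
    # Unweighted moments: integral of T_k over [-1, 1] is 2 / (1 - k^2) for even k, else 0
    m = np.zeros(n + 1)
    even = j % 2 == 0
    m[even] = 2.0 / (1.0 - j[even] ** 2)
    return 0.5 * (b - a) * np.sum(half * c * m)
```

For example, with $a = 0$ and $\alpha = -1/2$, `basic_cc(lambda x: x**-0.5, 0.5, 1.0)` handles the right half of $[0, 1]$, whose exact value is $2 - \sqrt{2}$.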

So we need a rule that applies either the modified CC method or the basic CC method, depending on the interval it's called on. This is what I envision for the adaptive method:

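import functools

import jax
import jax.numpy as jnp

# adaptive_quadrature, fixed_quadcc and QuadratureInfo are existing quadax pieces;
# compute_chebmom, fixed_quadcc_alglogweight and alglogweightfn are hypothetical
# helpers that this proposal still has to write.


def quadcc_alglogweight(fun, interval, args=(), weightargs=None, full_output=False, epsabs=None, epsrel=None, max_ninter=50, order=32, norm=jnp.inf):
    # compute modified Chebyshev moments based on weightargs (placeholder helper)
    chebmom = compute_chebmom(weightargs, order)

    def weightrule(fun, a, b, args, norm, n):
        # modified CC rule: used when the singularity sits at an endpoint of [a, b]
        return fixed_quadcc_alglogweight(
            fun,
            a,
            b,
            args,
            norm,
            n,
            weightargs=weightargs,
            chebmom=chebmom,
        )

    def defaultrule(fun, a, b, args, norm, n):
        # plain CC on a well-behaved interval, with the weight folded into the
        # integrand; a new name is needed, since rebinding `fun` in a lambda
        # would make it call itself recursively
        def weighted_fun(x, *args):
            return fun(x, *args) * alglogweightfn(x, **weightargs)

        return fixed_quadcc(weighted_fun, a, b, args, norm, n)

    @functools.partial(jax.jit, static_argnums=(0, 4, 5))
    def rule(fun, a, b, args, norm, n):
        # rule switches depending on the interval; `s == a or b` is always truthy
        # and `or` fails on traced values, so compare both endpoints and use `|`
        s = weightargs["singularity"]
        on_endpoint = (a == s) | (b == s)
        # fun, norm and n are static (not arrays), so close over them instead of
        # passing them as cond operands; both branches must return matching outputs
        return jax.lax.cond(
            on_endpoint,
            lambda a, b, args: weightrule(fun, a, b, args, norm, n),
            lambda a, b, args: defaultrule(fun, a, b, args, norm, n),
            a,
            b,
            args,
        )

    y, info = adaptive_quadrature(
        rule,
        fun,
        interval,
        args,
        full_output,
        epsabs,
        epsrel,
        max_ninter,
        n=order,
        norm=norm,
    )
    info = QuadratureInfo(info.err, info.neval * order, info.status, info.info)
    return y, info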

I'm not sure whether this will work or be jitted properly, but I tried to avoid writing my own adaptive_quadrature routine. If you see any obvious limitations to this idea I'd love to hear them; otherwise I'll try writing it up. If it fails, I'll probably have to write a specific adaptive_quadrature routine to handle weighted integration.
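On the jit question, a toy JAX sketch of just the switching part may help: the integrand is a static argument closed over by both branches, only array values (`a`, `b`) are passed to `jax.lax.cond`, and the predicate uses `|` rather than Python's `or`, which would try to turn a traced value into a Python bool. `make_rule`, `left_point` and `midpoint` are hypothetical stand-ins for the weighted and plain CC rules.

```python
import functools

import jax
import jax.numpy as jnp


def make_rule(special_rule, regular_rule, singularity):
    """Return a jitted rule that uses special_rule when [a, b] touches the singularity."""

    @functools.partial(jax.jit, static_argnums=(0,))
    def rule(fun, a, b):
        on_endpoint = (a == singularity) | (b == singularity)
        return jax.lax.cond(
            on_endpoint,
            lambda a, b: special_rule(fun, a, b),
            lambda a, b: regular_rule(fun, a, b),
            a,
            b,
        )

    return rule


# Toy rules with the same output shape and dtype, as lax.cond requires.
def left_point(fun, a, b):
    return (b - a) * fun(a)


def midpoint(fun, a, b):
    return (b - a) * fun((a + b) / 2)


rule = make_rule(left_point, midpoint, singularity=0.0)
```

Here `rule(f, 0.0, 0.5)` takes the special branch and `rule(f, 0.25, 0.5)` the regular one, with a single compiled function serving both.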

f0uriest commented 4 months ago

I think the basic approach makes sense, but a few comments:

twhentschel commented 3 months ago

Does your first point involve modifying the existing adaptive_quadrature method or writing a similar but new method for weighted functions? Edit: My concern is that different weight functions will have different switching conditions. For example, for the algebraic-logarithmic weight function, we need to check whether there is a singularity at an endpoint:

if a1 == weight_func.singularity1:
    ...  # use first special rule
elif b1 == weight_func.singularity2:
    ...  # use second special rule
else:
    ...  # use regular rule

Whereas for the Cauchy weight function, we need to check whether the singularity lies inside the interval:

if a1 < weight_func.singularity < b1:
    ...  # use special rule
else:
    ...  # use regular rule

I'm not sure how to make a general adaptive_quadrature method that can decide which switching condition to use based on the weight function. On the other hand, copy-pasting the current adaptive_quadrature method and slightly modifying it for each weight function doesn't seem like a great option either. Thoughts? :End edit
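To see why the Cauchy check differs, here is a minimal NumPy sketch of a principal-value rule on one interval: when $a < c < b$ it subtracts the singularity, using $\mathrm{PV}\int_a^b \frac{dx}{x - c} = \log\frac{b - c}{c - a}$, and otherwise the integrand is regular and an ordinary rule applies. This swaps in singularity subtraction with Gauss-Legendre for the modified Clenshaw-Curtis rule that QUADPACK's QAWC routine uses; `gauss` and `cauchy_pv` are hypothetical names.

```python
import numpy as np


def gauss(h, a, b, n=20):
    """Plain n-point Gauss-Legendre rule for the integral of h over [a, b]."""
    t, w = np.polynomial.legendre.leggauss(n)
    return 0.5 * (b - a) * np.dot(w, h(0.5 * (b - a) * t + 0.5 * (a + b)))


def cauchy_pv(f, a, b, c, n=20):
    """Principal value of the integral of f(x) / (x - c) over [a, b]; c not an endpoint."""
    if a < c < b:
        # Pole inside: integrate the smooth remainder (f(x) - f(c)) / (x - c) and
        # add f(c) * PV of the integral of 1 / (x - c), i.e. f(c) * log((b - c) / (c - a)).
        fc = f(c)
        smooth = gauss(lambda x: (f(x) - fc) / (x - c), a, b, n)
        return smooth + fc * np.log((b - c) / (c - a))
    # Pole outside [a, b]: the integrand is regular, so use the ordinary rule.
    return gauss(lambda x: f(x) / (x - c), a, b, n)
```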

I really like your second point of using a class instead of passing around a dict of weight function parameters etc.! I definitely think that's the way to go.

For your last question: for the algebraic-logarithmic weight function, QUADPACK/scipy (which I'm treating as the standard) uses slightly different rules depending on whether the singularity is at a or b. (Not both at once: in practice you split the interval in half so that each half has only one singular endpoint.) I haven't looked much into the other weight functions (Cauchy or sin/cos weights) yet.
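A minimal NumPy sketch of that splitting, assuming both algebraic factors and no log factors: with $m = (a + b)/2$, only $(x - a)^\alpha$ is singular on $[a, m]$, so $(b - x)^\beta$ is folded into the integrand, and vice versa on $[m, b]$. The right-endpoint case reuses the left-endpoint rule through the reflection $x \to lo + hi - x$. For brevity, the left-endpoint rule uses the substitution $u = (x - lo)^{\alpha + 1}$ followed by Gauss-Legendre, a swapped-in technique rather than QUADPACK's modified CC rules; all function names are hypothetical.

```python
import numpy as np


def left_singular(g, lo, hi, alpha, n=20):
    """Integral of (x - lo)**alpha * g(x) over [lo, hi], via u = (x - lo)**(alpha + 1)."""
    p = alpha + 1.0
    U = (hi - lo) ** p
    t, w = np.polynomial.legendre.leggauss(n)
    u = 0.5 * U * (t + 1.0)                  # Gauss nodes on [0, U]
    # (x - lo)**alpha dx = du / p and x = lo + u**(1 / p)
    return 0.5 * U * np.dot(w, g(lo + u ** (1.0 / p))) / p


def right_singular(g, lo, hi, beta, n=20):
    """Integral of (hi - x)**beta * g(x) over [lo, hi], by reflecting x -> lo + hi - x."""
    return left_singular(lambda y: g(lo + hi - y), lo, hi, beta, n)


def alg_both_ends(f, a, b, alpha, beta, n=20):
    """Integral of (x - a)**alpha * (b - x)**beta * f(x) over [a, b], split at the midpoint."""
    m = 0.5 * (a + b)
    left = left_singular(lambda x: f(x) * (b - x) ** beta, a, m, alpha, n)
    right = right_singular(lambda x: f(x) * (x - a) ** alpha, m, b, beta, n)
    return left + right
```

As a check, $\alpha = \beta = -1/2$ with $f = 1$ on $[0, 1]$ gives $\int_0^1 x^{-1/2}(1 - x)^{-1/2}\,dx = \pi$, each half contributing $\pi/2$.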

f0uriest commented 3 months ago

Ah I see what you mean. What if we made the check a method of the weight function? Like:

if weight_function is not None and weight_function.special_interval(a,b):
    foo = weight_function.integrate(a,b)
else:
    foo = base_rule.integrate(a,b)

I started to work up some classes for the basic integration rules here: https://github.com/f0uriest/quadax/pull/6 you should be able to use something similar for the weight functions.
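A minimal sketch of that idea in plain Python, with hypothetical `AlgLogWeight` and `CauchyWeight` classes that carry their own switching test, so a single generic `integrate_piece` (also hypothetical) needs no per-weight branching. Under `jax.jit`, the `or`/`if` would become `|` and `jax.lax.cond`; frozen dataclasses are hashable, so a weight object could also be passed as a static argument.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AlgLogWeight:
    """(x - c1)**alpha * (c2 - x)**beta, possibly with log factors; singular at c1 and c2."""
    c1: float
    c2: float
    alpha: float = 0.0
    beta: float = 0.0

    def special_interval(self, a, b):
        # A modified rule is needed when a singular point is an endpoint.
        return a == self.c1 or b == self.c2


@dataclass(frozen=True)
class CauchyWeight:
    """1 / (x - c); singular at c."""
    c: float

    def special_interval(self, a, b):
        # A modified rule is needed when the pole lies strictly inside (a, b).
        return a < self.c < b


def integrate_piece(weight, a, b, special_rule, base_rule):
    """Dispatch one subinterval to the weight's special rule or to the base rule."""
    if weight is not None and weight.special_interval(a, b):
        return special_rule(weight, a, b)
    return base_rule(a, b)
```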

twhentschel commented 3 months ago

Good idea! I'll start building off of the classes you've introduced.

twhentschel commented 2 months ago

Sorry for the delay, just wanted to give a short update: I have a version of the code that implements weighted quadrature (for just one specific weighting function for now), and next I want to implement some test integrals to make sure it works as expected.