jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.4k stars 2.79k forks source link

Using `lpmn_values` for Legendre polynomials. #14101

Open BeeGass opened 1 year ago

BeeGass commented 1 year ago

I have been looking for a JAX function that does something similar to SciPy's `scipy.special.eval_legendre`. I don't know the math well, but I have been told that JAX's `jax.scipy.special.lpmn_values` could be used for this: setting its order parameter to m=0 should yield the Legendre polynomial.

Legendre function Wiki Page

The polynomial solutions when $\lambda$ is an integer (denoted $n$) and $\mu = 0$ are the Legendre polynomials $P_n$.

However I found that when I looked at the source code:

Found Here

# Guard quoted from JAX's `lpmn_values` source: only calls where the two index
# arguments are equal are implemented; anything else raises immediately.
# NOTE(review): `m` and `n` here look like maximum order/degree bounds rather
# than a single (order, degree) pair. If so, calling with m == n and taking the
# order-0 slice of the result may already give P_0..P_n. Verify against the
# jax.scipy.special.lpmn_values documentation.
if m != n:
    raise NotImplementedError('Computations for m!=n are not yet supported.')

It would be really great to be able to evaluate the Legendre polynomials. I don't know the math well enough (yet) to implement this myself, so I'm hoping a more experienced person might be able to help.

BeeGass commented 1 year ago

To be transparent, I used ChatGPT to help with this. This seems to be a reasonable workaround for the time being:

def legendre_recurrence(n, x, max_n):
    """
    Computes the Legendre polynomial of degree n at point x using the recurrence relation.

    A table P_0(x) .. P_max_n(x) is filled with
        P_0 = 1,  P_1 = x,
        P_i = ((2i - 1) x P_{i-1} - (i - 1) P_{i-2}) / i,
    and row ``n`` of the table is returned.

    Args:
    n: int (or integer array), the degree(s) to select; must lie in [0, max_n].
    x: float or array, the point(s) at which to evaluate the polynomial.
    max_n: int, the maximum degree of n in the batch. It sizes the table, so it
        must be concrete (not a traced value).

    Returns:
    The value of the Legendre polynomial of degree n at point x (shape of x,
    or n.shape + x.shape when n is an array).
    """
    # Accept plain Python floats as documented (they have no `.shape`).
    x = jnp.asarray(x)
    # Always keep at least two rows so the P_1 seed below has somewhere to go,
    # even when every requested degree is 0.
    rows = max(int(max_n) + 1, 2)
    p = jnp.zeros((rows,) + x.shape)
    p = p.at[0].set(1.0)  # P_0(x) = 1
    p = p.at[1].set(x)  # P_1(x) = x

    # Fill degrees 2 .. max_n with the three-term recurrence.
    def body_fun(i, p):
        p_i = ((2 * i - 1) * x * p[i - 1] - (i - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    p = jax.lax.fori_loop(2, rows, body_fun, p)

    return p[n]
def eval_legendre(n, x, out=None):
    """
    Evaluates the Legendre polynomials of degrees in ``n`` at the points in ``x``.

    Shape semantics:
      * ``n`` and ``x`` both 1-D: element-wise pairing, P_{n[i]}(x[i]) for
        i < min(len(n), len(x)). This is the diagonal of the degree-by-point
        table and matches scipy.special.eval_legendre when the lengths agree.
      * either argument 0-D: ``n`` and ``x`` are broadcast together element-wise.
      * otherwise: each slice along n's leading axis is evaluated at each slice
        along x's leading axis via nested ``jax.vmap``.

    Args:
    n: array-like of non-negative integer degrees.
    x: array-like, the points at which to evaluate the polynomials.
    out: unsupported. JAX arrays are immutable, so results cannot be written
        into a caller-supplied buffer. Must be None.

    Returns:
    A JAX array of Legendre polynomial values.

    Raises:
    NotImplementedError: if ``out`` is not None.
    """
    if out is not None:
        # jnp.copy has no `out` parameter and JAX arrays cannot be mutated in
        # place, so writing into `out` can never succeed.
        raise NotImplementedError(
            "The 'out' argument to eval_legendre is not supported."
        )

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    # Size of the recurrence table shared by every evaluation.
    max_n = n.max()

    def _single(ni, xi):
        return legendre_recurrence(ni, xi, max_n)

    if n.ndim == 0 or x.ndim == 0:
        # vmap cannot map over a 0-d array, so broadcast and flatten instead.
        n_b, x_b = jnp.broadcast_arrays(n, x)
        flat = jax.vmap(_single)(n_b.ravel(), x_b.ravel())
        return flat.reshape(n_b.shape)

    if n.ndim == 1 and x.ndim == 1:
        # Pair degrees with points directly rather than building the full
        # len(n) x len(x) table and keeping only its diagonal.
        k = min(n.shape[0], x.shape[0])
        return jax.vmap(_single)(n[:k], x[:k])

    return jax.vmap(lambda ni: jax.vmap(lambda xi: _single(ni, xi))(x))(n)
def test_eval_legendre():
    """Check eval_legendre against scipy.special.eval_legendre at four points spanning [-1, 1]."""
    n = np.array([0, 1, 2, 3])

    print(f"n = {n}")
    print(f"n shape = {n.shape}")

    x = np.linspace(-1, 1, len(n))

    print(f"x = {x}")
    print(f"x shape = {x.shape}")

    # JAX implementation first, SciPy reference second.
    y_pred, y = eval_legendre(n, x), ss.eval_legendre(n, x)

    print(f"y_pred = {y_pred}")
    print(f"y_pred shape = {y_pred.shape}")
    print(f"y = {y}")
    print(f"y shape = {y.shape}")

    assert np.allclose(y_pred, y, atol=1e-8, rtol=1e-5), "Results do not match"
    print("Results match")

output:

n = [0 1 2 3]
n shape = (4,)
x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
y_pred = [ 1.         -0.33333334 -0.3333333   1.        ]
y_pred shape = (4,)
y = [ 1.         -0.33333333 -0.33333333  1.        ]
y shape = (4,)
Results match
BeeGass commented 1 year ago

I also did this for evaluating generalized Laguerre polynomials:

def genlaguerre_recurrence(n, alpha, x, max_n):
    """
    Computes the generalized Laguerre polynomial of degree n with parameter alpha at point x using the recurrence relation.

    A table L_0 .. L_max_n is filled with
        L_0 = 1,
        L_1 = 1 + alpha - x,
        L_i = ((2i - 1 + alpha - x) L_{i-1} - (i - 1 + alpha) L_{i-2}) / i,
    and row ``n`` of the table is returned.

    Args:
    n: int (or integer array), the degree(s) to select; must lie in [0, max_n].
    alpha: float, the parameter of the generalized Laguerre polynomial.
    x: float or array, the point(s) at which to evaluate the polynomial.
    max_n: int, the maximum degree of n in the batch. It sizes the table, so it
        must be concrete (not a traced value).

    Returns:
    The value of the generalized Laguerre polynomial of degree n with parameter alpha at point x.
    """
    # Accept plain Python floats as documented (they have no `.shape`).
    x = jnp.asarray(x)
    # Keep at least two rows so the L_1 seed always has a slot.
    rows = max(int(max_n) + 1, 2)
    p = jnp.zeros((rows,) + x.shape)
    p = p.at[0].set(1.0)  # L_0 = 1
    # Seed L_1 explicitly. Starting the loop at i = 1 would read p[i - 2] == p[-1]
    # and rely on negative-index wrap-around onto a row that is still zero.
    p = p.at[1].set(1.0 + alpha - x)

    # Fill degrees 2 .. max_n with the three-term recurrence.
    def body_fun(i, p):
        p_i = ((2 * i + alpha - 1 - x) * p[i - 1] - (i + alpha - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    p = jax.lax.fori_loop(2, rows, body_fun, p)

    return p[n]
def eval_genlaguerre(n, alpha, x, out=None):
    """
    Evaluates the generalized Laguerre polynomials of degrees in ``n`` with parameter ``alpha`` at the points in ``x``.

    Shape semantics:
      * ``n`` and ``x`` both 1-D: element-wise pairing, L_{n[i]}^(alpha)(x[i])
        for i < min(len(n), len(x)). This is the diagonal of the degree-by-point
        table and matches scipy.special.eval_genlaguerre when the lengths agree.
      * either argument 0-D: ``n`` and ``x`` are broadcast together element-wise.
      * otherwise: each slice along n's leading axis is evaluated at each slice
        along x's leading axis via nested ``jax.vmap``.

    Args:
    n: array-like of non-negative integer degrees.
    alpha: float, the parameter shared by every evaluation. It is not broadcast
        against ``n`` or ``x``.
    x: array-like, the points at which to evaluate the polynomials.
    out: unsupported. JAX arrays are immutable, so results cannot be written
        into a caller-supplied buffer. Must be None.

    Returns:
    A JAX array of generalized Laguerre polynomial values.

    Raises:
    NotImplementedError: if ``out`` is not None.
    """
    if out is not None:
        # jnp.copy has no `out` parameter and JAX arrays cannot be mutated in
        # place, so writing into `out` can never succeed.
        raise NotImplementedError(
            "The 'out' argument to eval_genlaguerre is not supported."
        )

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    # Size of the recurrence table shared by every evaluation.
    max_n = n.max()

    def _single(ni, xi):
        return genlaguerre_recurrence(ni, alpha, xi, max_n)

    if n.ndim == 0 or x.ndim == 0:
        # vmap cannot map over a 0-d array, so broadcast and flatten instead.
        n_b, x_b = jnp.broadcast_arrays(n, x)
        flat = jax.vmap(_single)(n_b.ravel(), x_b.ravel())
        return flat.reshape(n_b.shape)

    if n.ndim == 1 and x.ndim == 1:
        # Pair degrees with points directly (what scipy.special.eval_genlaguerre
        # does) instead of building the full table and taking its diagonal.
        k = min(n.shape[0], x.shape[0])
        return jax.vmap(_single)(n[:k], x[:k])

    return jax.vmap(lambda ni: jax.vmap(lambda xi: _single(ni, xi))(x))(n)
def test_eval_genlaguerre():
    """Check eval_genlaguerre (alpha = 2) against scipy.special.eval_genlaguerre at four points in [-1, 1]."""
    n = np.array([0, 1, 2, 3])
    alpha = 2.0

    print(f"n = {n}")
    print(f"n shape = {n.shape}")

    x = np.linspace(-1, 1, len(n))

    print(f"x = {x}")
    print(f"x shape = {x.shape}")

    # JAX implementation first, then the SciPy reference converted to a JAX array.
    y_pred, y = eval_genlaguerre(n, alpha, x), jnp.array(ss.eval_genlaguerre(n, alpha, x))

    print(f"y_pred = {y_pred}")
    print(f"y_pred shape = {y_pred.shape}")
    print(f"y = {y}")
    print(f"y shape = {y.shape}")

    assert np.allclose(y_pred, y, atol=1e-8, rtol=1e-5), "Results do not match"
    print("Results match")

output:

n = [0 1 2 3]
n shape = (4,)
x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
y_pred = [1.        3.3333333 4.7222223 2.3333335]
y_pred shape = (4,)
y = [1.        3.3333333 4.7222223 2.3333333]
y shape = (4,)
Results match
BeeGass commented 1 year ago

I have since improved this implementation a bit: it was numerically unstable and wasn't JIT-compatible. Here is what I came up with; the test stays the same.

Legendre:

import jax
from jax import numpy as jnp
from jaxtyping import Array, Float, Int
from typing import Any, Callable, Mapping, Optional

def legendre_recurrence(
    n: Int[Array, "n"], x: Float[Array, "m"], n_max: Int[Array, ""]
) -> Float[Array, "n m"]:
    """
    Evaluate the Legendre polynomial(s) of degree ``n`` at ``x`` via the three-term recurrence.

    A table P_0(x) .. P_{n_max}(x) is built with ``jax.lax.scan`` and the row(s)
    selected by ``n`` are returned.

    Args:
        n: Degree(s) to select from the table; every entry must lie in [0, n_max].
        x (jnp.ndarray): The point(s) at which the Legendre polynomials are evaluated. Can be a
                        single point (float) or an array of points.
        n_max (int): The highest degree to compute. Must be a non-negative integer whose value
                        is concrete (not traced), because it fixes the scan length.

    Returns:
        jnp.ndarray: ``table[n]``, where ``table`` has shape ``(max(n_max, 1) + 1,) + x.shape``
                    and row i holds P_i(x).

    Notes:
        P_0(x) = 1 and P_1(x) = x are seeded directly; higher degrees use
        P_{k+1}(x) = ((2k + 1) * x * P_k(x) - k * P_{k-1}(x)) / (k + 1).
        For n_max of 0 or 1 the scan is empty and only the two seeded rows exist.
    """
    x = jnp.asarray(x)
    p0 = jnp.ones(x.shape)  # P_0
    p1 = x.astype(p0.dtype)  # P_1, in the same float dtype as the table

    def body_fun(carry, k):
        p_km1, p_k = carry
        p_kp1 = ((2 * k + 1) * x * p_k - k * p_km1) / (k + 1)
        return (p_k, p_kp1), p_kp1

    # Scanning over k = 1 .. n_max - 1 produces P_2 .. P_{n_max}. An empty range
    # (n_max <= 1) yields zero rows, where `length=n_max - 1` would go negative.
    _, higher = jax.lax.scan(body_fun, (p0, p1), jnp.arange(1, n_max))
    table = jnp.concatenate((jnp.stack((p0, p1)), higher), axis=0)

    return table[n]
def eval_legendre(n: Int[Array, "n"], x: Float[Array, "m"]) -> Float[Array, "n m"]:
    """
    Evaluate Legendre polynomials of specified degrees at provided point(s).

    Every evaluation shares one recurrence table of height ``max(n) + 1``.

    Shape semantics:
        * ``n`` and ``x`` both 1-D: element-wise pairing, P_{n[i]}(x[i]) for
          i < min(len(n), len(x)). This matches scipy.special.eval_legendre when
          the lengths agree.
        * either argument 0-D: ``n`` and ``x`` are broadcast together element-wise.
        * otherwise: each slice along n's leading axis is evaluated at each slice
          along x's leading axis via nested ``jax.vmap``.

    Parameters:
        n (jnp.ndarray): Non-negative integer degrees.
        x (jnp.ndarray): The point(s) at which the polynomials are evaluated.

    Returns:
        jnp.ndarray: Legendre polynomial values, passed through ``jnp.squeeze``,
                    so size-1 axes are dropped (e.g. a single pair yields a 0-d array).
    """
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    n_max = n.max()

    def _single(ni, xi):
        return legendre_recurrence(ni, xi, n_max)

    if n.ndim == 0 or x.ndim == 0:
        # vmap cannot map over a 0-d array, so broadcast and flatten instead.
        n_b, x_b = jnp.broadcast_arrays(n, x)
        p = jax.vmap(_single)(n_b.ravel(), x_b.ravel()).reshape(n_b.shape)
    elif n.ndim == 1 and x.ndim == 1:
        # Pair degrees with points directly instead of computing the full
        # len(n) x len(x) table and keeping only its diagonal.
        k = min(n.shape[0], x.shape[0])
        p = jax.vmap(_single)(n[:k], x[:k])
    else:
        p = jax.vmap(lambda ni: jax.vmap(lambda xi: _single(ni, xi))(x))(n)

    return jnp.squeeze(p)
BeeGass commented 1 year ago

Laguerre:

def genlaguerre_recurrence(
    n: Int[Array, "n"],
    alpha: Float[Array, ""],
    x: Float[Array, "m"],
    max_n: Int[Array, ""],
) -> Float[Array, "n m"]:
    """
    Computes the generalized Laguerre polynomial of degree n with parameter alpha at point x using the recurrence relation.

    A table L_0 .. L_max_n is filled with
        L_0 = 1,
        L_1 = 1 + alpha - x,
        L_i = ((2i - 1 + alpha - x) L_{i-1} - (i - 1 + alpha) L_{i-2}) / i,
    and row ``n`` of the table is returned.

    Args:
        n: int (or integer array), the degree(s) to select; must lie in [0, max_n].
        alpha: float, the parameter of the generalized Laguerre polynomial.
        x: float or array, the point(s) at which to evaluate the polynomial.
        max_n: int, the maximum degree of n in the batch. It sizes the table,
            so it must be concrete (not a traced value).

    Returns:
        The value of the generalized Laguerre polynomial of degree n with parameter alpha at point x.
    """
    # Accept plain Python floats (they have no `.shape`).
    x = jnp.asarray(x)
    # Keep at least two rows so the L_1 seed always has a slot.
    rows = max(int(max_n) + 1, 2)
    p = jnp.zeros((rows,) + x.shape)
    p = p.at[0].set(1.0)  # L_0 = 1
    # Seed L_1 explicitly. Starting the loop at i = 1 would read p[i - 2] == p[-1]
    # and rely on negative-index wrap-around onto a row that is still zero.
    p = p.at[1].set(1.0 + alpha - x)

    # Fill degrees 2 .. max_n with the three-term recurrence.
    def body_fun(i, p):
        p_i = ((2 * i + alpha - 1 - x) * p[i - 1] - (i + alpha - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    p = jax.lax.fori_loop(2, rows, body_fun, p)

    return p[n]
def eval_genlaguerre(
    n: Int[Array, "n"],
    alpha: Float[Array, ""],
    x: Float[Array, "m"],
    out: Float[Array, "n m"] = None,
) -> Float[Array, "n m"]:
    """
    Evaluates the generalized Laguerre polynomials of degrees in ``n`` with parameter ``alpha`` at the points in ``x``.

    Shape semantics:
        * ``n`` and ``x`` both 1-D: element-wise pairing, L_{n[i]}^(alpha)(x[i])
          for i < min(len(n), len(x)). This matches scipy.special.eval_genlaguerre
          when the lengths agree.
        * either argument 0-D: ``n`` and ``x`` are broadcast together element-wise.
        * otherwise: each slice along n's leading axis is evaluated at each slice
          along x's leading axis via nested ``jax.vmap``.

    Args:
        n: array-like, the non-negative integer degrees.
        alpha: float, the parameter shared by every evaluation. It is not
            broadcast against ``n`` or ``x``.
        x: array-like, the points at which to evaluate the polynomials.
        out: unsupported. JAX arrays are immutable, so results cannot be written
            into a caller-supplied buffer. Must be None.

    Returns:
        Generalized Laguerre polynomial values, passed through ``jnp.squeeze``,
        so size-1 axes are dropped.

    Raises:
        NotImplementedError: if ``out`` is not None.
    """
    if out is not None:
        # jnp.copy has no `out` parameter and JAX arrays cannot be mutated in
        # place, so writing into `out` can never succeed.
        raise NotImplementedError(
            "The 'out' argument to eval_genlaguerre is not supported."
        )

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    max_n = n.max()

    def _single(ni, xi):
        return genlaguerre_recurrence(ni, alpha, xi, max_n)

    if n.ndim == 0 or x.ndim == 0:
        # vmap cannot map over a 0-d array, so broadcast and flatten instead.
        n_b, x_b = jnp.broadcast_arrays(n, x)
        p = jax.vmap(_single)(n_b.ravel(), x_b.ravel()).reshape(n_b.shape)
    elif n.ndim == 1 and x.ndim == 1:
        # Pair degrees with points directly (what scipy.special.eval_genlaguerre
        # does) instead of building the full table and taking its diagonal.
        k = min(n.shape[0], x.shape[0])
        p = jax.vmap(_single)(n[:k], x[:k])
    else:
        p = jax.vmap(lambda ni: jax.vmap(lambda xi: _single(ni, xi))(x))(n)

    return jnp.squeeze(p)