f0uriest / interpax

Interpolation and function approximation with JAX

Slow jax.grad of CubicSpline #41

Open odstrcilt opened 1 month ago

odstrcilt commented 1 month ago

interpax is a very nice package. However, when its interpolation is wrapped in jax.grad, it is about 10x slower than my simple cubic spline interpolation below. The interpolation itself is not significantly different. The code was written with the help of ChatGPT. It is not a high-priority issue.


import jax
import jax.numpy as jnp

# tridiagonal (Thomas algorithm) solve supporting autodiff
@jax.jit
def tridiagonal_solve(a, b, c, d):
    """
    Solves a tridiagonal system Ax = d for x, where A is a tridiagonal matrix with
    diagonals a, b, and c.
    a - subdiagonal (length n-1)
    b - main diagonal (length n)
    c - superdiagonal (length n-1)
    d - right-hand side vector (length n)
    """
    n = len(d)
    ac, bc, cc, dc = map(jnp.asarray, (a, b, c, d))  # ensure all inputs are JAX arrays

    # Forward sweep using lax.fori_loop
    def forward_sweep(i, val):
        bc, dc = val
        w = ac[i-1] / bc[i-1]
        bc = bc.at[i].set(bc[i] - w * cc[i-1])
        dc = dc.at[i].set(dc[i] - w * dc[i-1])
        return bc, dc

    bc, dc = jax.lax.fori_loop(1, n, forward_sweep, (bc, dc))

    # Back substitution using lax.fori_loop
    def back_substitution(i, xc):
        idx = n - 2 - i
        xc = xc.at[idx].set((dc[idx] - cc[idx] * xc[idx+1]) / bc[idx])
        return xc

    xc = jnp.zeros_like(dc)
    xc = xc.at[-1].set(dc[-1] / bc[-1])
    xc = jax.lax.fori_loop(0, n-1, back_substitution, xc)

    return xc
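
# Quick self-check (added for illustration, not part of the original report):
# verify the solver against a dense solve of the same system.
def _check_tridiagonal_solve():
    n = 7
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    a = jax.random.uniform(k1, (n - 1,))       # subdiagonal
    b = 4.0 + jax.random.uniform(k2, (n,))     # diagonally dominant main diagonal
    c = jax.random.uniform(k3, (n - 1,))       # superdiagonal
    d = jax.random.uniform(k4, (n,))           # right-hand side
    A = jnp.diag(b) + jnp.diag(a, k=-1) + jnp.diag(c, k=1)
    return jnp.allclose(tridiagonal_solve(a, b, c, d),
                        jnp.linalg.solve(A, d), atol=1e-4)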

# JAX implementation of a natural cubic spline (second derivative zero at both ends)
@jax.jit
def compute_cubic_spline_coeffs(x, y):
    n = len(x)
    h = x[1:] - x[:-1]  # knot spacings

    # Create the tridiagonal matrix A components
    lower = jnp.zeros(n)
    upper = jnp.zeros(n)
    diag = jnp.ones(n)

    lower = lower.at[:-2].set(h[:-1])
    upper = upper.at[2:].set(h[1:])
    diag = diag.at[1:-1].set((h[:-1] + h[1:]) * 2)

    # Right-hand side vector B
    B = jnp.zeros(n)
    B = B.at[1:-1].set(3 * ((y[2:] - y[1:-1]) / h[1:] - (y[1:-1] - y[:-2]) / h[:-1]))

    c = tridiagonal_solve(lower[:-1], diag, upper[1:], B)

    # Compute b and d
    b = (y[1:] - y[:-1]) / h - h * (2 * c[:-1] + c[1:]) / 3
    d = (c[1:] - c[:-1]) / (3 * h)

    return x, y[:-1], b, c[:-1], d
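
# Optional cross-check (added sketch, assuming SciPy is available): the identity
# rows at both ends of the system force c[0] = c[-1] = 0, i.e. a natural spline,
# so the coefficients should match scipy's CubicSpline with bc_type="natural".
def _check_coeffs_against_scipy():
    import numpy as np
    from scipy.interpolate import CubicSpline as SciPyCubicSpline
    xs = np.linspace(0.0, 1.0, 8)
    ys = np.sin(2 * np.pi * xs)
    _, a_j, b_j, c_j, d_j = compute_cubic_spline_coeffs(jnp.array(xs), jnp.array(ys))
    sp = SciPyCubicSpline(xs, ys, bc_type="natural")
    # scipy stores coefficients highest power first: sp.c rows are [d, c, b, a]
    return all(np.allclose(u, v, atol=1e-4)
               for u, v in zip(sp.c[::-1], (a_j, b_j, c_j, d_j)))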

@jax.jit
def evaluate_cubic_spline(coeffs, x_new):
    x, a, b, c, d = coeffs
    idx = jnp.searchsorted(x, x_new) - 1
    idx = jnp.clip(idx, 0, len(a) - 1)
    dx = x_new - x[idx]

    return a[idx] + dx * (b[idx] + dx * (c[idx] + d[idx] * dx))

@jax.jit
def cubic_spline(xi, x, y):
    c = compute_cubic_spline_coeffs(x, y)
    return evaluate_cubic_spline(c, xi)
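
# Example of the gradient usage in question (a sketch of the reported setup):
# differentiate a scalar loss through the spline w.r.t. the data values y.
x_knots = jnp.linspace(0.0, 1.0, 100)
y_knots = jnp.sin(x_knots)
xq = jnp.linspace(0.0, 1.0, 1000)

def loss(y):
    return jnp.sum(cubic_spline(xq, x_knots, y) ** 2)

grad_loss = jax.jit(jax.grad(loss))  # this is the path reported as ~10x slower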

@jax.jit
def integrate_cubic_spline(coeffs, A, B):
    # assumes A <= B
    # NOTE: extrapolates the spline by zeros on both ends

    x, a, b, c, d = coeffs

    # Find the indices of the intervals containing A and B
    idx = jnp.searchsorted(x, jnp.array([A, B])) - 1

    # Clip indices to ensure they are within bounds
    idx = jnp.clip(idx, 0, len(x) - 2)

    # Calculate the integration limits for each interval
    dx_a = A - x[idx[0]]

    x_clipped = jnp.clip(x, A, B)
    dx_b = x_clipped[1:] - x_clipped[:-1]
    dx_b = dx_b.at[idx[0]].add(dx_a)

    # Calculate the integral contribution of each interval: Horner evaluation
    # of the antiderivative dx*(a + dx*(b/2 + dx*(c/3 + dx*d/4)))
    int_b = jnp.zeros_like(dx_b)
    int_a = 0.0
    for i, coeff in enumerate([d, c, b, a]):
        int_b = (int_b + coeff / (4 - i)) * dx_b
        int_a = (int_a + coeff[idx[0]] / (4 - i)) * dx_a

    # Sum up the integral contributions from all intervals
    total_integral = jnp.sum(int_b) - int_a

    return total_integral
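
# Sanity check of the integral (added for illustration): integrating the
# spline of y = x**2 over [A, B] should be close to (B**3 - A**3) / 3,
# up to the natural-spline approximation error near the boundaries.
xs = jnp.linspace(0.0, 2.0, 50)
coeffs = compute_cubic_spline_coeffs(xs, xs ** 2)
print(integrate_cubic_spline(coeffs, 0.25, 1.75), (1.75 ** 3 - 0.25 ** 3) / 3)
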
f0uriest commented 1 month ago

Do you have an example of the code above being used, compared to interpax? Also, are you running on CPU or GPU? On CPU the tridiagonal solve is likely faster, but on GPU the loops will cause a lot of overhead compared to just calling out to cusolver on a full matrix. Though if that's the case, I would expect to see a performance difference in the forward pass, not just in the gradient.
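
For reference, a minimal benchmark sketch along these lines (not from the thread; it assumes interpax.CubicSpline takes (x, y) like scipy.interpolate.CubicSpline and is callable on query points, and it reuses the reporter's cubic_spline above) would separate the forward and gradient costs on a given backend:

import time
import jax
import jax.numpy as jnp
import interpax

x = jnp.linspace(0.0, 1.0, 200)
y = jnp.sin(x)
xq = jnp.linspace(0.0, 1.0, 10_000)

def f_interpax(y):
    return jnp.sum(interpax.CubicSpline(x, y)(xq) ** 2)

def f_custom(y):
    return jnp.sum(cubic_spline(xq, x, y) ** 2)

print("backend:", jax.default_backend())
for name, f in [("interpax", f_interpax), ("custom", f_custom)]:
    for tag, fn in [("forward", jax.jit(f)), ("grad", jax.jit(jax.grad(f)))]:
        fn(y).block_until_ready()  # compile once before timing
        t0 = time.perf_counter()
        for _ in range(100):
            out = fn(y)
        out.block_until_ready()
        print(f"{name} {tag}: {(time.perf_counter() - t0) / 100:.2e} s")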