f0uriest / interpax

Interpolation and function approximation with JAX
MIT License

CubicSpline not supporting grad JIT. #32

Closed: blackening closed this issue 5 months ago

blackening commented 6 months ago

JAX: 0.4.28, interpax: 0.3.1

Minimal reproduction:

import jax
import jax.numpy as jnp
from interpax import CubicSpline

x0, y0 = 0, 0  # target at the left endpoint
x1, y1 = 1, -0.3  # target at the right endpoint

xm, ym = 0.5, 1  # target at the midpoint

N = 6
p_x = jnp.linspace(x0, x1, N)
p_y = jnp.linspace(y0, y1, N)  # spline control points, initialized on a straight line

def loss(p_y):
    f = CubicSpline(p_x, p_y, check=False)  # I guessed the check might be at fault; disabling it makes no difference.

    return (f(x0)-y0)**2 + (f(x1)-y1)**2 + (f(xm)-ym)**2

print('loss', loss(p_y))

dloss = jax.grad(loss)

print('dloss', dloss(p_y)) # Works fine

jdloss = jax.jit(jax.grad(loss))

print('jdloss', jdloss(p_y))  # Never gets here: the jitted call raises.

Fails on: operation a:bool[6] = is_finite b, from line C:....\interpax_test.py:15 (loss). See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

I have attached the full output: interpax_test_output.txt
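For readers unfamiliar with this error class, here is a minimal sketch of how a finiteness check can work under jax.grad yet fail under jax.jit. It does not reproduce interpax's internals: checked_sum and traceable_sum are hypothetical functions written only for illustration, and the assumption is that the failing is_finite result in the trace above ended up in Python control flow in a similar way.

```python
import jax
import jax.numpy as jnp


def checked_sum(y):
    # Hypothetical function: a Python `if` needs a concrete bool.
    # Under jit, `y` is an abstract tracer, so converting the result
    # of jnp.all(...) to a Python bool raises.
    if not jnp.all(jnp.isfinite(y)):
        raise ValueError("non-finite input")
    return jnp.sum(y**2)


def traceable_sum(y):
    # Hypothetical jit-compatible variant: the check stays inside the
    # traced computation and yields NaN instead of raising.
    ok = jnp.all(jnp.isfinite(y))
    return jnp.where(ok, jnp.sum(y**2), jnp.nan)


y = jnp.linspace(0.0, -0.3, 6)

# Without jit, grad sees concrete values, so the Python `if` is fine.
eager_grad = jax.grad(checked_sum)(y)

# With jit, the same function fails while tracing.
try:
    jax.jit(jax.grad(checked_sum))(y)
    jit_failed = False
except jax.errors.TracerBoolConversionError:
    jit_failed = True

# The variant without Python control flow traces and differentiates.
jit_grad = jax.jit(jax.grad(traceable_sum))(y)

print("jit of checked_sum failed:", jit_failed)
print("gradients agree:", bool(jnp.allclose(eager_grad, jit_grad)))
```

This matches the pattern in the report, where jax.grad(loss) runs and jax.jit(jax.grad(loss)) does not. Returning NaN is only one possible design choice; whether it suits a given library is a separate question.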

unalmis commented 6 months ago

This is the same issue as #31. The bugfix made in the PR mentioned there has not yet been released to interpax's PyPI package. Until then, one can manually edit the interpax installation with the changes in the linked pull request.

f0uriest commented 5 months ago

Resolved with release v0.3.2
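To confirm that a local environment has the release containing the fix, something like the following sketch may help. The helpers version_tuple and has_fix are hypothetical names introduced here, and the comparison looks only at the leading numeric components, so a pre-release such as 0.3.2rc1 would be treated the same as 0.3.2.

```python
import re
from importlib import metadata


def version_tuple(version: str) -> tuple:
    # Hypothetical helper: "0.3.2" -> (0, 3, 2). Only the first three
    # numeric components are kept; suffixes such as "rc1" are ignored.
    return tuple(int(n) for n in re.findall(r"\d+", version)[:3])


def has_fix(version: str, fixed_in: str = "0.3.2") -> bool:
    # The thread states the issue was resolved with release v0.3.2.
    return version_tuple(version) >= version_tuple(fixed_in)


try:
    installed = metadata.version("interpax")
    print("interpax", installed, "includes fix:", has_fix(installed))
except metadata.PackageNotFoundError:
    print("interpax is not installed")
```

If the installed version is older, upgrading with pip install --upgrade interpax should pull in v0.3.2 or later.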