xichaoqiang opened this issue 2 years ago.
Thanks for the request. Related feature requests are at #9956 and #11002.
Seconded - hoping to use J1, J0 in simple interferometry applications, without necessarily needing the higher orders. Is there a reason these are hard to do in JAX?
Hi, I'd love to see this implementation. I had to move to another framework because I needed J0 and J1 for a NeRF-related project.
Seconding the request for J0, J1 functions! Would implementing these in Jax be similar to other packages such as scipy?
I actually see Bessel Jv has now been implemented! But it seems to be numerically unstable:
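```python
import jax.scipy.special  # makes jax.scipy.special available
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from scipy.special import j1 as j1s  # scipy version

def j1(x):
    # jax version: bessel_jn returns [J_0, ..., J_v] stacked along axis 0
    return jax.scipy.special.bessel_jn(x, v=1)[1]

z = np.linspace(0, 100, 1000)
plt.plot(z, vmap(j1)(z))
plt.plot(z, j1s(z))
```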
This is improved by moving to a higher n_iter.
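To put numbers on the n_iter effect, here is a minimal sketch that measures the worst-case deviation from `scipy.special.j1` for a few `n_iter` values. It assumes 64-bit mode, uses SciPy as the reference, and restricts the grid to x in [10, 100] so that large `n_iter` does not run into the low-x NaNs reported below; `max_abs_error` is a hypothetical helper name.

```python
import jax
jax.config.update("jax_enable_x64", True)  # run the comparison in float64
import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jscs
from scipy.special import j1 as j1_scipy

def max_abs_error(n_iter, z):
    """Largest |J1_jax - J1_scipy| over z for one n_iter (hypothetical helper)."""
    j1_jax = jscs.bessel_jn(z, v=1, n_iter=n_iter)[1]
    return float(np.nanmax(np.abs(np.asarray(j1_jax) - j1_scipy(np.asarray(z)))))

# moderately large x only, away from the small-x region
z = jnp.linspace(10.0, 100.0, 901)
for n_iter in (50, 100, 200):
    print(n_iter, max_abs_error(n_iter, z))
```

Our reading is that the backward recurrence inside `bessel_jn` has to start well above x to be accurate, which is why the default `n_iter=50` struggles once x approaches or exceeds 50.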
Hi, notice also the instabilities at high x, which push toward a higher n_iter, but that in turn produces nan at low x:
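```python
import jax
jax.config.update("jax_enable_x64", True)  # the output below is float64
import jax.numpy as jnp
import jax.scipy.special as jscs

N = 300
x = jnp.linspace(0, N, 50 * N)
jscs.bessel_jn(x, v=0, n_iter=200)
```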
leads to
```
DeviceArray([[       nan,        nan,        nan, ..., 0.88876336,
              0.89054682, 0.8923063 ]], dtype=float64)
```
while
```python
import matplotlib.pyplot as plt
import scipy.special as scs

plt.plot(x, jscs.bessel_jn(x, v=0, n_iter=200).squeeze())
plt.plot(x, scs.jn(0, x), ls="--")  # scipy.special.jn
```
gives
cc @tlu7
Here is a possible solution for J0(x), based on discussion #14132 and the Boost code:
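```python
import jax
import jax.numpy as jnp

def _evaluate_rational(z, num, denom):
    # Port of Boost's evaluate_rational for a scalar z:
    # (num[0] + num[1] z + ...) / (denom[0] + denom[1] z + ...)
    count = len(num)

    def true_fn_update(z):
        # z <= 1: Horner's scheme in z, starting from the highest coefficient
        def body_true(val1):
            s1, s2, i = val1
            s1 = s1 * z + num[i]
            s2 = s2 * z + denom[i]
            return s1, s2, i - 1

        def cond_true(val1):
            s1, s2, i = val1
            return i >= 0

        val1_init = (num[-1], denom[-1], count - 2)
        s1, s2, _ = jax.lax.while_loop(cond_true, body_true, val1_init)
        return s1 / s2

    def false_fn_update(z):
        # z > 1: called with 1/z; Horner's scheme from the lowest coefficient
        def body_false(val1):
            s1, s2, i = val1
            s1 = s1 * z + num[i]
            s2 = s2 * z + denom[i]
            return s1, s2, i + 1

        def cond_false(val1):
            s1, s2, i = val1
            return i < count

        val1_init = (num[0], denom[0], 1)
        s1, s2, _ = jax.lax.while_loop(cond_false, body_false, val1_init)
        return s1 / s2

    # both branches are evaluated; where keeps the one Boost would use
    return jnp.where(z <= 1, true_fn_update(z), false_fn_update(1 / z))

# vectorize over z (axis 0); the coefficient arrays are shared
v_ratio = jax.vmap(_evaluate_rational, in_axes=(0, None, None))
```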
```python
def J0(x):
    # Port of Boost's bessel_j0. x must be a 1-D array because v_ratio
    # is vmapped over its first argument.
    P1 = jnp.array([-4.1298668500990866786e+11,
                    2.7282507878605942706e+10,
                    -6.2140700423540120665e+08,
                    6.6302997904833794242e+06,
                    -3.6629814655107086448e+04,
                    1.0344222815443188943e+02,
                    -1.2117036164593528341e-01])
    Q1 = jnp.array([2.3883787996332290397e+12,
                    2.6328198300859648632e+10,
                    1.3985097372263433271e+08,
                    4.5612696224219938200e+05,
                    9.3614022392337710626e+02,
                    1.0,
                    0.0])
    assert len(P1) == len(Q1)
    P2 = jnp.array([-1.8319397969392084011e+03,
                    -1.2254078161378989535e+04,
                    -7.2879702464464618998e+03,
                    1.0341910641583726701e+04,
                    1.1725046279757103576e+04,
                    4.4176707025325087628e+03,
                    7.4321196680624245801e+02,
                    4.8591703355916499363e+01])
    Q2 = jnp.array([-3.5783478026152301072e+05,
                    2.4599102262586308984e+05,
                    -8.4055062591169562211e+04,
                    1.8680990008359188352e+04,
                    -2.9458766545509337327e+03,
                    3.3307310774649071172e+02,
                    -2.5258076240801555057e+01,
                    1.0])
    assert len(P2) == len(Q2)
    PC = jnp.array([2.2779090197304684302e+04,
                    4.1345386639580765797e+04,
                    2.1170523380864944322e+04,
                    3.4806486443249270347e+03,
                    1.5376201909008354296e+02,
                    8.8961548424210455236e-01])
    QC = jnp.array([2.2779090197304684318e+04,
                    4.1370412495510416640e+04,
                    2.1215350561880115730e+04,
                    3.5028735138235608207e+03,
                    1.5711159858080893649e+02,
                    1.0])
    assert len(PC) == len(QC)
    PS = jnp.array([-8.9226600200800094098e+01,
                    -1.8591953644342993800e+02,
                    -1.1183429920482737611e+02,
                    -2.2300261666214198472e+01,
                    -1.2441026745835638459e+00,
                    -8.8033303048680751817e-03])
    QS = jnp.array([5.7105024128512061905e+03,
                    1.1951131543434613647e+04,
                    7.2642780169211018836e+03,
                    1.4887231232283756582e+03,
                    9.0593769594993125859e+01,
                    1.0])
    assert len(PS) == len(QS)
    # x1, x2: first two zeros of J0, each split as x11/256 + x12 (x21/256 + x22)
    # so that x - x1 (x - x2) is computed accurately near the zero
    x1 = 2.4048255576957727686e+00
    x2 = 5.5200781102863106496e+00
    x11 = 6.160e+02
    x12 = -1.42444230422723137837e-03
    x21 = 1.4130e+03
    x22 = 5.46860286310649596604e-04
    one_div_root_pi = 5.641895835477562869480794515607725858e-01

    def t1(x):  # x in (0, 4]
        y = x * x
        r = v_ratio(y, P1, Q1)
        factor = (x + x1) * ((x - x11/256) - x12)
        return factor * r

    def t2(x):  # x in (4, 8]
        y = 1 - (x * x) / 64
        r = v_ratio(y, P2, Q2)
        factor = (x + x2) * ((x - x21/256) - x22)
        return factor * r

    def t3(x):  # x > 8: large-x form with rational corrections in (8/x)^2
        y = 8 / x
        y2 = y * y
        rc = v_ratio(y2, PC, QC)
        rs = v_ratio(y2, PS, QS)
        factor = one_div_root_pi / jnp.sqrt(x)
        sx = jnp.sin(x)
        cx = jnp.cos(x)
        return factor * (rc * (cx + sx) - y * rs * (sx - cx))

    x = jnp.abs(x)  # J0 is even
    # every branch is evaluated on all of x; select keeps the first matching one
    return jnp.select(
        [x == 0, x <= 4, x <= 8, x > 8],
        [1, t1(x), t2(x), t3(x)],
        default=x)
```
Test:
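```python
import matplotlib.pyplot as plt
import scipy as sc
import scipy.special  # makes sc.special available

x = jnp.linspace(0, 50, 1000)  # example 1-D grid; the original grid was not given
plt.plot(x, J0(x))
plt.plot(x, sc.special.j0(x), ls="--")
plt.plot(x, J0(x) - sc.special.j0(x))
```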
I would be glad if you can simplify and improve the code. Or maybe it is already a better implementation. Moreover, SciPy may not be the most accurate reference to compare against.
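Since the coefficient arrays are fixed, one possible simplification is to replace the two `while_loop`s with `jnp.polyval`, which also works elementwise on arrays of any shape, so no `vmap` is needed. This is a sketch checked only on the small example in the usage note; `evaluate_rational` is a hypothetical drop-in replacement for `_evaluate_rational`/`v_ratio`, keeping Boost's lowest-power-first coefficient order.

```python
import numpy as np
import jax.numpy as jnp

def evaluate_rational(num, denom, z):
    """Evaluate sum(num[i] z**i) / sum(denom[i] z**i) elementwise.

    Coefficients are in increasing powers of z (Boost's ordering), and
    num and denom must have the same length, as asserted in J0.
    """
    num = jnp.asarray(num)
    denom = jnp.asarray(denom)
    z = jnp.asarray(z)
    small = z <= 1
    # z <= 1: evaluate in z directly; jnp.polyval wants the highest power first
    r_small = jnp.polyval(num[::-1], z) / jnp.polyval(denom[::-1], z)
    # z > 1: divide top and bottom by z**(n-1) and evaluate in w = 1/z,
    # which is jnp.polyval on the coefficients in their original order
    w = 1 / jnp.where(small, 1.0, z)  # avoid 1/0 in the unused branch
    r_large = jnp.polyval(num, w) / jnp.polyval(denom, w)
    return jnp.where(small, r_small, r_large)

# (1 + 2z + 3z^2) / (4 + 5z + z^2) at z = 0.5 (direct branch) and z = 3 (1/z branch)
r = evaluate_rational([1.0, 2.0, 3.0], [4.0, 5.0, 1.0], jnp.array([0.5, 3.0]))
print(r)
```

Inside J0, each `v_ratio(y, P, Q)` call would then become `evaluate_rational(P, Q, y)`.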
Just bumping this thread: the current implementations of J0 and J1 yield NaNs and have numerically unstable gradients.
Here is an MWE showing the issue:
```python
import jax
import jax.numpy as jnp
import jax.scipy.special  # makes jax.scipy.special available
import matplotlib.pyplot as plt
from jax import grad, vmap
from scipy.special import j0, j1, jn

def j1_jax(x):
    return jax.scipy.special.bessel_jn(x, v=1, n_iter=50)[1]

def j0_jax(x):
    return jax.scipy.special.bessel_jn(x, v=0, n_iter=50)[0]

# derivative of J1 by automatic differentiation through bessel_jn
grad_jax = vmap(grad(j1_jax))

def grad_scipy(x):
    # analytic derivative: J1'(x) = (J0(x) - J2(x)) / 2
    return (j0(x) - jn(2, x)) / 2

x = jnp.linspace(-10 * jnp.pi, 10 * jnp.pi, 1000)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
ax1.set_title("J1 comparison")
ax1.plot(x, j1(x), label="scipy J1")
ax1.plot(x, j1_jax(x), label="jax J1", zorder=2)
ax1.legend()
ax2.set_title("J1 Gradient Comparison")
ax2.plot(x, grad_scipy(x), label="analytic grad")
ax2.plot(x, grad_jax(x), label="jax grad", zorder=2)
ax2.legend()
```
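One way to make the gradient as good as the forward values is to attach the identity already used in `grad_scipy`, J1'(x) = (J0(x) - J2(x))/2, through `jax.custom_jvp` instead of differentiating through the recurrence. This is a sketch assuming `bessel_jn` itself is accurate for the x of interest (here: 64-bit mode, moderate positive x); `j1_stable` is a hypothetical name.

```python
import jax
jax.config.update("jax_enable_x64", True)  # bessel_jn behaves much better in float64
import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jscs
from scipy.special import j0, jn

@jax.custom_jvp
def j1_stable(x):
    return jscs.bessel_jn(x, v=1, n_iter=50)[1]

@j1_stable.defjvp
def j1_stable_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # one bessel_jn call gives J0, J1 and J2 stacked along axis 0
    j = jscs.bessel_jn(x, v=2, n_iter=50)
    # J1'(x) = (J0(x) - J2(x)) / 2
    return j[1], 0.5 * (j[0] - j[2]) * x_dot

grad_j1 = jax.vmap(jax.grad(j1_stable))

xs = jnp.array([1.0, 5.0, 12.0])
analytic = 0.5 * (j0(np.asarray(xs)) - jn(2, np.asarray(xs)))
print(np.max(np.abs(np.asarray(grad_j1(xs)) - analytic)))
```

This only moves the problem to the forward evaluation: wherever `bessel_jn` returns NaN (for example very close to x = 0), the gradient will too.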
Just to bump this: for our use case we seem OK with a bit of a hack, patching together the core JAX Bessel function for small x with a different series expansion for large x. We currently only need J1 out to moderate values, but a similar approach could be used to generate Jn of arbitrary order, and perhaps should be implemented.
I should add that the core JAX implementation is very numerically unstable in 32-bit precision but much better in 64-bit.
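A minimal sketch of such a patch for J1 only, in 64-bit mode. It is not the notebook's code: the large-x branch here uses the first terms of the standard Hankel asymptotic expansion, a power series covers tiny x, and both thresholds `X_TINY` and `X_SWITCH` are illustrative assumptions.

```python
import jax
jax.config.update("jax_enable_x64", True)  # reported to be far more stable in 64-bit
import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jscs
from scipy.special import j1 as j1_scipy

X_TINY = 1e-3    # below this, use the power series (assumed threshold)
X_SWITCH = 20.0  # above this, use the asymptotic expansion (assumed threshold)

def _j1_asymptotic(x):
    # Leading terms of the Hankel expansion for large x:
    # J1(x) ~ sqrt(2/(pi x)) * (P cos(w) - Q sin(w)),  w = x - 3*pi/4
    p = 1.0 + 15.0 / (128.0 * x**2)
    q = 3.0 / (8.0 * x) - 105.0 / (1024.0 * x**3)
    w = x - 0.75 * jnp.pi
    return jnp.sqrt(2.0 / (jnp.pi * x)) * (p * jnp.cos(w) - q * jnp.sin(w))

def j1_patched(x):
    x = jnp.asarray(x, dtype=jnp.float64)
    ax = jnp.abs(x)  # J1 is odd: J1(-x) = -J1(x)
    # feed each branch only safe inputs so the unused branches stay finite
    mid_x = jnp.where((ax >= X_TINY) & (ax < X_SWITCH), ax, 1.0)
    large_x = jnp.where(ax >= X_SWITCH, ax, X_SWITCH)
    tiny = ax / 2 - ax**3 / 16  # J1(x) = x/2 - x^3/16 + O(x^5)
    mid = jscs.bessel_jn(mid_x, v=1, n_iter=50)[1]
    large = _j1_asymptotic(large_x)
    val = jnp.where(ax < X_TINY, tiny, jnp.where(ax < X_SWITCH, mid, large))
    return jnp.sign(x) * val

xs = np.array([-20.5, -1.0, 0.0, 5e-4, 1.0, 10.0, 19.9, 20.0, 35.0, 100.0])
print(np.max(np.abs(np.asarray(j1_patched(xs)) - j1_scipy(xs))))
```

With only these few asymptotic terms, the switch point trades accuracy near `X_SWITCH` against how far `bessel_jn` must be trusted; adding terms would allow a lower switch point.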
Quick demo:
https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
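To check the 32-bit versus 64-bit difference directly, a sketch like the following compares both dtypes against SciPy in one session. It enables 64-bit mode so float64 arrays are allowed alongside float32, and it avoids x = 0; `precision_check` is a hypothetical helper.

```python
import jax
jax.config.update("jax_enable_x64", True)  # allow float64 arrays alongside float32
import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jscs
from scipy.special import j1

x_grid = np.linspace(0.5, 20.0, 400)  # avoid x = 0, where bessel_jn divides by zero
reference = j1(x_grid)

def precision_check(dtype_name, n_iter=50):
    """Return (count of non-finite values, max abs error vs SciPy) for one dtype."""
    x = jnp.asarray(x_grid, dtype=dtype_name)
    approx = np.asarray(jscs.bessel_jn(x, v=1, n_iter=n_iter)[1], dtype=np.float64)
    n_bad = int(np.sum(~np.isfinite(approx)))
    err = float(np.nanmax(np.abs(approx - reference)))
    return n_bad, err

for name in ("float32", "float64"):
    print(name, precision_check(name))
```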
Just to update: I went the full way and implemented the Cephes version of J1 in pure JAX, and it is now much faster and more numerically stable. I can do this for J0 too; @jecampagne has taken a similar approach there.
Thoughts?
I've now done a full pure-JAX implementation of the SciPy Bessel function routines for J0 and J1, and a non-iterative version of Jv that uses the first recurrence relation, $J_{v+1}(x) = \frac{2v}{x} J_v(x) - J_{v-1}(x)$, to construct v > 1 from J0 and J1. This way it has poorer precision close to x = 0, but it agrees with SciPy to machine precision for larger values and is fast.
https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
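The recurrence step can be sketched as below. This is a sketch, not the notebook's implementation: `jn_from_j0_j1` is a hypothetical name, n must be a static Python integer, and J0 and J1 are taken as inputs so any implementation of them can be plugged in.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64 for a tight comparison
import numpy as np
import jax.numpy as jnp
from scipy.special import j0, j1, jv

def jn_from_j0_j1(n, x, j0_vals, j1_vals):
    """J_n(x) for a static integer n >= 0 by forward recurrence.

    Uses J_{v+1}(x) = (2v/x) J_v(x) - J_{v-1}(x), starting from
    precomputed J_0(x) and J_1(x).
    """
    x = jnp.asarray(x)
    j_prev, j_curr = jnp.asarray(j0_vals), jnp.asarray(j1_vals)
    if n == 0:
        return j_prev
    for v in range(1, n):  # after step v, j_curr holds J_{v+1}
        j_prev, j_curr = j_curr, (2.0 * v / x) * j_curr - j_prev
    return j_curr

xs = np.array([5.0, 10.0, 30.0])
print(jn_from_j0_j1(5, xs, j0(xs), j1(xs)) - jv(5, xs))
```

Forward recurrence is stable while n stays below roughly x; once n exceeds x, rounding errors in J0 and J1 are amplified, which is consistent with the poorer precision near x = 0 mentioned above.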
@jakevdp - what do you think? I would do a pull request - wanted to check this is the sort of thing you might want.
Just seconding to say I would love this functionality - I'm trying to use J1 for some heat equation simulations. Thanks for the links already @benjaminpope!
Assigning to @jakevdp for further triage.
Anyway, yes, happy to do the PR.
Feature request: Bessel functions from SciPy (J0, J1, J2, Y0, Y1, Y2), which have been implemented in autograd.