jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Feature request: Bessel functions from scipy (J0, J1, J2, Y0, Y1, Y2), which have been implemented in autograd #12402

Open xichaoqiang opened 2 years ago

xichaoqiang commented 2 years ago

Feature request: Bessel functions from scipy (J0, J1, J2, Y0, Y1, Y2), which have been implemented in autograd.


jakevdp commented 2 years ago

Thanks for the request. Related feature requests are at #9956 and #11002.

benjaminpope commented 2 years ago

Seconded - hoping to use J1 and J0 in simple interferometry applications, without necessarily needing the higher orders. Is there a reason these are hard to do in Jax?

KostadinovShalon commented 2 years ago

Hi, I'd love to see this implemented. I had to move to another framework because I needed J0 and J1 for a NeRF-related project.

LouisDesdoigts commented 1 year ago

Seconding the request for the J0 and J1 functions! Would an implementation of these in Jax look similar to other packages such as scipy?

benjaminpope commented 1 year ago

I actually see Bessel Jv (bessel_jn) has now been implemented! But it seems to be numerically unstable:


import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import vmap

def j1(x):
    # jax version
    return jax.scipy.special.bessel_jn(x, v=1)[1]

from scipy.special import j1 as j1s  # scipy version

z = np.linspace(0, 100, 1000)

plt.plot(z, vmap(j1)(z))
plt.plot(z, j1s(z))

This is improved by moving to higher n_iter.

[plot: jax vmap(j1)(z) vs scipy j1s(z), showing the instability]
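
For what it's worth, a minimal sketch of raising n_iter; the value 100 is an arbitrary illustration, traded against extra compute:

import jax

def j1_hi(x):
    # same call as above, but with a deeper iteration than the default 50
    return jax.scipy.special.bessel_jn(x, v=1, n_iter=100)[1]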

jecampagne commented 1 year ago

Hi. Notice also the instabilities at high x, which incline one to use a high n_iter, but this comes in conjunction with nan at low x:

import jax.numpy as jnp
import jax.scipy.special as jscs

N = 300
x = jnp.linspace(0, N, 50 * N)
jscs.bessel_jn(x, v=0, n_iter=200)

leads to

DeviceArray([[       nan,        nan,        nan, ..., 0.88876336,
              0.89054682, 0.8923063 ]], dtype=float64)

while

import matplotlib.pyplot as plt
import scipy.special as scs

plt.plot(x, jscs.bessel_jn(x, v=0, n_iter=200).squeeze())
plt.plot(x, scs.jn(0, x), ls="--")  # scipy.special.jn

gives:

[plot: jax bessel_jn(v=0, n_iter=200) vs scipy jn(0, x)]

jakevdp commented 1 year ago

cc/ @tlu7

jecampagne commented 1 year ago

Here is a possible solution for J0(x), based on discussion #14132 and the Boost code:

import jax
import jax.numpy as jnp

def _evaluate_rational(z, num, denom):
    # Evaluate the rational function P(z)/Q(z) with coefficients given in
    # increasing-power order, following Boost's evaluate_rational: for
    # z <= 1, use Horner's scheme directly; for z > 1, evaluate the
    # reversed polynomials in 1/z for numerical stability.
    count = len(num)

    def true_fn_update(z):

      def body_true(val1):
        # unpack carry: Horner partial sums and a descending index
        s1, s2, i = val1
        s1 = s1 * z + num[i]
        s2 = s2 * z + denom[i]
        return s1, s2, i - 1

      def cond_true(val1):
        s1, s2, i = val1 
        return i>=0

      val1_init = (num[-1], denom[-1], count-2)
      s1, s2, _ = jax.lax.while_loop(cond_true, body_true, val1_init)

      return s1/s2

    def false_fn_update(z):
      def body_false(val1):
        # unpack carry: Horner partial sums and an ascending index
        s1, s2, i = val1
        s1 = s1 * z + num[i]
        s2 = s2 * z + denom[i]
        return s1, s2, i + 1

      def cond_false(val1):
        s1, s2, i = val1 
        return i<count

      val1_init = (num[0], denom[0],1)
      s1, s2, _ = jax.lax.while_loop(cond_false, body_false, val1_init)

      return s1/s2

    # jnp.where evaluates both branches; the 1/z branch is only selected
    # for z > 1, matching Boost's large-argument evaluation
    return jnp.where(z <= 1, true_fn_update(z), false_fn_update(1 / z))

v_ratio = jax.vmap(_evaluate_rational, in_axes=(0,None, None))
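
# Purely illustrative sanity check with made-up coefficients (not from the
# Boost tables): P(z) = 1 + 2z + 3z^2 and Q(z) = 1 + z^2, so the ratio at
# z = 0.5 is 2.75 / 1.25 = 2.2.
assert jnp.allclose(
    v_ratio(jnp.array([0.5]),
            jnp.array([1.0, 2.0, 3.0]),
            jnp.array([1.0, 0.0, 1.0])),
    2.2)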

def J0(x):
  P1 = jnp.array([-4.1298668500990866786e+11,
                  2.7282507878605942706e+10,
                  -6.2140700423540120665e+08,
                  6.6302997904833794242e+06,
                  -3.6629814655107086448e+04,
                  1.0344222815443188943e+02,
                  -1.2117036164593528341e-01]) 
  Q1 = jnp.array([2.3883787996332290397e+12,
                  2.6328198300859648632e+10,
                  1.3985097372263433271e+08,
                  4.5612696224219938200e+05,
                  9.3614022392337710626e+02,
                  1.0,
                  0.0])
  assert len(P1) == len(Q1)

  P2 = jnp.array([-1.8319397969392084011e+03,
                  -1.2254078161378989535e+04,
                  -7.2879702464464618998e+03,
                  1.0341910641583726701e+04,
                  1.1725046279757103576e+04,
                  4.4176707025325087628e+03,
                  7.4321196680624245801e+02,
                  4.8591703355916499363e+01])
  Q2 = jnp.array([-3.5783478026152301072e+05,
                  2.4599102262586308984e+05,
                  -8.4055062591169562211e+04,
                  1.8680990008359188352e+04,
                  -2.9458766545509337327e+03,
                  3.3307310774649071172e+02,
                  -2.5258076240801555057e+01,
                  1.0])
  assert len(P2) == len(Q2)

  PC = jnp.array([2.2779090197304684302e+04,
                  4.1345386639580765797e+04,
                  2.1170523380864944322e+04,
                  3.4806486443249270347e+03,
                  1.5376201909008354296e+02,
                  8.8961548424210455236e-01])
  QC = jnp.array([2.2779090197304684318e+04,
                  4.1370412495510416640e+04,
                  2.1215350561880115730e+04,
                  3.5028735138235608207e+03,
                  1.5711159858080893649e+02,
                  1.0])

  assert len(PC) == len(QC)

  PS = jnp.array([-8.9226600200800094098e+01,
                  -1.8591953644342993800e+02,
                  -1.1183429920482737611e+02,
                  -2.2300261666214198472e+01,
                  -1.2441026745835638459e+00,
                  -8.8033303048680751817e-03])
  QS = jnp.array([5.7105024128512061905e+03,
                  1.1951131543434613647e+04,
                  7.2642780169211018836e+03,
                  1.4887231232283756582e+03,
                  9.0593769594993125859e+01,
                  1.0])
  assert len(PS) == len(QS)

  x1 = 2.4048255576957727686e+00
  x2 = 5.5200781102863106496e+00
  x11 = 6.160e+02
  x12 = -1.42444230422723137837e-03
  x21 = 1.4130e+03
  x22 = 5.46860286310649596604e-04
  one_div_root_pi =  5.641895835477562869480794515607725858e-01

  def t1(x):  # x<=4
    y = x * x
    r = v_ratio(y, P1, Q1)
    factor = (x + x1) * ((x - x11/256) - x12)
    return factor * r

  def t2(x): # x<=8
    y = 1 - (x * x)/64
    r = v_ratio(y, P2, Q2)
    factor = (x + x2) * ((x - x21/256) - x22)
    return factor * r

  def t3(x):  # x > 8
    y = 8 / x
    y2 = y * y
    rc = v_ratio(y2, PC, QC)
    rs = v_ratio(y2, PS, QS)
    factor = one_div_root_pi / jnp.sqrt(x)
    sx = jnp.sin(x)
    cx = jnp.cos(x)
    return factor * (rc * (cx + sx) - y * rs * (sx - cx))

  x = jnp.abs(x)  # J0 is even
  return jnp.select(
      [x == 0, x <= 4, x <= 8, x > 8],
      [1.0, t1(x), t2(x), t3(x)],
      default=x)

Test:

import scipy as sc
import scipy.special

# x as defined in the earlier snippet
plt.plot(x, J0(x))
plt.plot(x, sc.special.j0(x), ls="--")

[plot: J0(x) overlaid on scipy.special.j0(x)]

plt.plot(x, J0(x)-sc.special.j0(x))

[plot: residuals J0(x) - scipy.special.j0(x)]

jecampagne commented 1 year ago

If you can simplify or improve the code, I would be glad... Or maybe there is already a better implementation. Moreover, scipy may not be the most accurate reference to compare against.

shashankdholakia commented 1 year ago

Just bumping this thread: the current implementation of J0 and J1 yields nans and has numerically unstable gradients.

Here is an MWE showing the issue:

import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from jax import grad, jit, vmap

from scipy.special import j0, j1, jn

def j1_jax(x):
    return jax.scipy.special.bessel_jn(x,v=1,n_iter=50)[1]

def j0_jax(x):
    return jax.scipy.special.bessel_jn(x,v=0,n_iter=50)[0]

grad_jax = vmap(grad(j1_jax))

def grad_scipy(x):
    # analytic derivative via the identity J1'(x) = (J0(x) - J2(x)) / 2
    return (j0(x) - jn(2, x)) / 2

x = jnp.linspace(-10*jnp.pi,10*jnp.pi,1000)

fig, (ax1,ax2) = plt.subplots(2,1, figsize=(10,10))
ax1.set_title("J1 comparison")
ax1.plot(x,j1(x), label='scipy J1')
ax1.plot(x,j1_jax(x), label='jax J1', zorder=2)
ax1.legend()

ax2.set_title("J1 Gradient Comparison")
ax2.plot(x, grad_scipy(x), label='analytic grad')
ax2.plot(x,grad_jax(x), label="jax grad", zorder=2)
ax2.legend()

[plots: J1 comparison and J1 gradient comparison]

benjaminpope commented 1 year ago

Just to bump this - for our use case we are OK with a bit of a hack, patching together the core Jax Bessel function for small x with a different series expansion for large x. We currently only need J1 out to moderate values, but a similar approach could be used to generate arbitrary Jn and perhaps should be implemented.

I should add that the core Jax implementation is very numerically unstable in 32-bit but much better in 64-bit.
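
For concreteness, here is a minimal sketch of that patching idea; the cutoff value and the use of the standard large-argument asymptotic J1(x) ~ sqrt(2/(pi x)) cos(x - 3 pi/4) are my assumptions, not necessarily what the notebook does:

import jax
import jax.numpy as jnp

CUTOFF = 25.0  # illustrative switch point, not a tuned value

def j1_patched(x):
    # series-based bessel_jn for small x
    small = jax.scipy.special.bessel_jn(x, v=1, n_iter=50)[1]
    # standard large-argument asymptotic for J1
    large = jnp.sqrt(2.0 / (jnp.pi * x)) * jnp.cos(x - 3.0 * jnp.pi / 4.0)
    return jnp.where(x < CUTOFF, small, large)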

Quick demo:

https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb

benjaminpope commented 1 year ago

Just to update - I went the whole way and implemented the Cephes version of J1 in pure Jax, and it is now much faster and more numerically stable. I can do this for J0 too - @jecampagne has taken a similar approach there.

Thoughts?

benjaminpope commented 1 year ago

I've now done a full pure-Jax implementation of the scipy Bessel function routine for J0 and J1, and a non-iterative version of Jv that constructs v > 1 from J0 and J1 via the recurrence J_{v+1}(x) = (2v/x) J_v(x) - J_{v-1}(x). It has poorer precision close to x = 0 this way, but is at machine precision with respect to scipy for larger values and is fast.

https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
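
For reference, a minimal sketch of that upward recurrence; j0_fn and j1_fn are placeholders for the pure-Jax J0/J1 implementations, and this is my reading of the approach rather than the notebook's exact code:

import jax
import jax.numpy as jnp

def jv_upward(v, x, j0_fn, j1_fn):
    # upward recurrence J_{v+1}(x) = (2v/x) J_v(x) - J_{v-1}(x),
    # starting from J0 and J1; v is a static Python int
    jm1, jcur = j0_fn(x), j1_fn(x)
    if v == 0:
        return jm1
    def body(i, carry):
        prev, cur = carry
        return cur, (2.0 * i / x) * cur - prev
    # after the loop, the second carry slot holds J_v
    return jax.lax.fori_loop(1, v, body, (jm1, jcur))[1]

Upward recurrence is known to lose accuracy once v exceeds x, consistent with the reduced precision near x = 0 noted above.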

@jakevdp - what do you think? I would do a pull request - wanted to check this is the sort of thing you might want.

femtomc commented 1 year ago

Just seconding to say I would love this functionality - I'm trying to use J1 for some heat equation simulations. Thanks for the links, @benjaminpope!

axch commented 1 year ago

Assigning to @jakevdp for further triage.

benjaminpope commented 1 year ago

Anyway, yes - happy to do the PR.