google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[FEATURE] Real-valued argument Fresnel integrals (add missing `jax.scipy.special.fresnel`) #22683

Closed: jeertmans closed this issue 5 days ago

jeertmans commented 1 month ago

Hello!

I really enjoy using JAX for my research, but I sometimes find that functions I wish were there are missing from jax.numpy and jax.scipy. One function that is important for my research is scipy.special.fresnel, so I decided to translate SciPy's C++ code into JAX-compatible code (see here), at least for real-valued arguments. For complex arguments, I rely on a custom complex-valued error function implementation, but its accuracy is poor.

However, the real-valued variant is very accurate: the difference from SciPy's implementation is usually below 1e-14 in double precision. The implementation currently follows SciPy's closely, but it can certainly be optimized; many computations are repeated, although @jit may already merge some of them.
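For context, scipy.special.fresnel(x) returns the pair (S(x), C(x)), where S(x) is the integral of sin(pi t^2 / 2) and C(x) the integral of cos(pi t^2 / 2) from 0 to x. The sketch below is NOT the SciPy/Cephes algorithm discussed above; it is a simpler stand-in that sums the Maclaurin series C(x) = sum_n (-1)^n (pi/2)^(2n) x^(4n+1) / ((2n)! (4n+1)) and the analogous series for S(x). It assumes real float64 inputs of moderate size only (a few units at most), because the alternating series loses precision to cancellation as |x| grows. The name `fresnel_series` and the fixed term count are hypothetical choices for illustration, not a JAX API.

```python
import math

import jax
import jax.numpy as jnp

# Fresnel integrals need float64 to get anywhere near SciPy's accuracy.
jax.config.update("jax_enable_x64", True)

# Hypothetical choice: 60 terms is plenty for the moderate |x| this sketch targets.
_NUM_TERMS = 60


@jax.jit
def fresnel_series(x):
    """Return (S(x), C(x)) for real x via the Maclaurin series.

    Hypothetical helper (not part of JAX). Uses the same (S, C) output
    order as scipy.special.fresnel. Only suitable for moderate |x|.
    """
    x = jnp.asarray(x, dtype=jnp.float64)
    w = 0.5 * math.pi * x * x   # w = (pi / 2) * x^2
    p = jnp.ones_like(x)        # p_k = w^k / k!, starting at k = 0
    c_sum = jnp.zeros_like(x)
    s_sum = jnp.zeros_like(x)
    for k in range(_NUM_TERMS):  # Python loop: unrolled at trace time
        if k > 0:
            p = p * w / k
        # Sign is (-1)^floor(k/2): +, +, -, -, +, +, ...
        sign = -1.0 if (k // 2) % 2 else 1.0
        term = sign * p / (2 * k + 1)
        if k % 2 == 0:
            c_sum = c_sum + term  # even k = 2n contributes to C(x)
        else:
            s_sum = s_sum + term  # odd k = 2n + 1 contributes to S(x)
    return x * s_sum, x * c_sum


s1, c1 = fresnel_series(1.0)
print(float(s1), float(c1))
```

Because the loop length is fixed in Python, `jax.jit` traces it into a static graph, so the function also composes with `jax.vmap` and `jax.grad` (the gradient of C should match cos(pi x^2 / 2)). Large |x| would need a separate expression, for example in terms of the auxiliary functions f(x) and g(x), which is the kind of branching a real-valued PR would have to handle.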

If you are interested, I'd be super happy to make a PR to implement the missing jax.scipy.special.fresnel, at least for real-valued arguments.

What do you think?

jakevdp commented 1 month ago

I think we'd be happy to accept a PR for the real-valued case. We have some notes on contributing here: https://jax.readthedocs.io/en/latest/contributing.html#contributing-code-using-pull-requests. Please let me know if you have any questions!

jeertmans commented 1 month ago

Thank you for the comment! I will do that later this week and link the PR to this issue when it's ready :)