jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

float128 please?? #2007

Open mc2engel opened 4 years ago

mc2engel commented 4 years ago

Hello!

I'm using JAX to implement MCMC on a 2D Ising lattice. I need to be able to compute things like 1.0 - x for very small x (x < 1e-15). Would it be possible for you to implement a float128 data type? It would seriously improve my life!

Thank you, Megan
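To make the problem concrete, here is a minimal NumPy sketch with illustrative values (not taken from the simulation). In IEEE double precision the spacing of floats just below 1.0 is 2**-53 ≈ 1.1e-16, so subtracting a much smaller x from 1.0 rounds back to exactly 1.0, and an x near 1e-15 keeps only about one significant digit. JAX's default float32 is far coarser still (spacing about 6e-8 near 1.0).

```python
import numpy as np

x = 1e-17  # illustrative tiny value, below half the float64 spacing under 1.0

# In float64 the difference rounds back to exactly 1.0: x is lost entirely.
diff64 = np.float64(1.0) - np.float64(x)
print(diff64 == 1.0)  # True

# A slightly larger x survives, but most of its significant digits do not.
y = 1e-15
recovered = 1.0 - (1.0 - y)  # equals y in exact arithmetic
rel_err = abs(recovered - y) / y
print(rel_err)  # clearly non-zero relative error
```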

hawkinsp commented 4 years ago

I'm curious whether by float128 you mean what NumPy calls float128 (which is really either 80-bit or 64-bit): https://docs.scipy.org/doc/numpy-1.15.0/user/basics.types.html#extended-precision or whether you mean IEEE754 quad precision types.

Do you care about using float128 inside jit?
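Since NumPy's float128 is an alias for `np.longdouble`, whose precision depends on the platform and compiler, a quick way to see what you actually get is to inspect `np.finfo`. The sketch below assumes only NumPy; `itemsize` reports storage, while `nmant` and `eps` report the precision that matters for 1.0 - x.

```python
import numpy as np

# np.longdouble is the type NumPy exposes as float128 on many x86 builds;
# storage size and real precision can differ (e.g. 80-bit x87 padded to 128).
info = np.finfo(np.longdouble)
print(np.dtype(np.longdouble).itemsize * 8, "bits of storage")
print(info.nmant, "explicit mantissa bits")
print(info.eps, "machine epsilon")

# Reference point: IEEE double precision.
print(np.finfo(np.float64).nmant, "mantissa bits for float64")  # 52
```

On platforms where `longdouble` is just a 64-bit double, `info.eps` equals the float64 epsilon and nothing is gained.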

mc2engel commented 4 years ago

Thanks for the quick reply. I meant what NumPy calls float128 (on my MacBook Pro I guess this is 80 bits); hopefully that will be sufficient for my purposes. I want to use it potentially inside jit, but certainly inside grad(). Ultimately I want to differentiate a dynamical simulation of the Ising model, consisting of sequential spin flips on a 2D lattice. If you have better/alternative ideas for dealing with numbers that are very close to 1, please let me know!!
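One common alternative for quantities very close to 1, sketched here under the assumption that the simulation can carry the small complement x directly instead of 1 - x, is to use functions that take x as input and never form 1 - x. NumPy's `log1p` and `expm1` do this, and `jax.numpy` provides `jnp.log1p` and `jnp.expm1` with the same meaning. The values below are hypothetical.

```python
import numpy as np

x = 1e-17  # hypothetical tiny probability, stored directly rather than as 1 - x

# Naive: forming 1 - x first already rounds to 1.0, so the log is exactly 0.
naive = np.log(1.0 - x)

# Stable: log1p(-x) evaluates log(1 - x) without forming 1 - x.
stable = np.log1p(-x)  # approximately -1e-17

# Going the other way: given a log-probability lp close to 0,
# 1 - exp(lp) == -expm1(lp), which stays accurate where exp(lp) rounds to 1.0.
lp = -1e-17
complement = -np.expm1(lp)  # approximately 1e-17

print(naive, stable, complement)
```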

mattjj commented 4 years ago

If you're only using grad and only using CPU then a stop-gap solution would be to use Autograd, which works on top of NumPy and thus might support those dtypes. On CPU it can also be faster than JAX-without-jit (until we lower our dispatch overheads).

@hawkinsp understands much more than me about how feasible it is to add support in JAX.

mc2engel commented 4 years ago

Ah, I'm hoping to extend my framework eventually to molecular-dynamics-style simulations, so I'd prefer to stick with JAX (and soon, JAX-MD).

hawkinsp commented 4 years ago

We have a few options here. One would be to plumb float80 or float128 support into XLA.

Another would be to implement a precision-doubling transformation (i.e., build a double-double implementation in JAX as a JAX transformation).

This might not be something we get to immediately (the latter sounds fun, though); contributions are welcome.
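For readers unfamiliar with the double-double idea: each value is stored as an unevaluated sum hi + lo of two doubles, and every operation uses error-free transformations so the rounding error of the high word is kept in the low word. The sketch below shows only addition, using Knuth's TwoSum on plain Python floats (float64) for clarity; it is an illustration of the arithmetic such a transformation would apply per primitive, not the implementation from any JAX PR. Multiplication would additionally need a TwoProd step (Dekker splitting or FMA), which is omitted.

```python
def two_sum(a: float, b: float) -> tuple[float, float]:
    """Knuth's TwoSum: returns (s, err) with s = fl(a + b) and s + err == a + b exactly."""
    s = a + b
    b_virtual = s - a
    a_virtual = s - b_virtual
    err = (a - a_virtual) + (b - b_virtual)
    return s, err


def dd_add(x: tuple[float, float], y: tuple[float, float]) -> tuple[float, float]:
    """Add two double-double numbers (hi, lo); simple variant where low words add with rounding."""
    s, e = two_sum(x[0], y[0])
    e += x[1] + y[1]
    return two_sum(s, e)  # renormalize so |lo| is small relative to hi


one = (1.0, 0.0)
tiny = (-1e-17, 0.0)  # -x with x = 1e-17

# Plain float64: (1 - x) - 1 collapses to 0.0 because 1 - x rounds to 1.0.
plain = (1.0 - 1e-17) - 1.0

# Double-double: the -1e-17 lost by plain float64 survives in the low word.
r = dd_add(dd_add(one, tiny), (-1.0, 0.0))
print(plain)        # 0.0
print(r[0] + r[1])  # -1e-17
```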

jeffgortmaker commented 4 years ago

+1 for extended precision. I've been using JAX a bit in applications with near-step functions and integral approximations where underflow can create issues.

An extended precision JAX transformation that works with jit and grad would be very cool.
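Until extended precision exists, one partial mitigation for underflow in steep, near-step functions is to work in log space. The sketch below assumes a logistic sigmoid as the near-step function (a hypothetical choice) and uses NumPy; JAX offers `jax.nn.log_sigmoid` and `jnp.logaddexp` for the same purpose. This avoids underflow in the log-domain quantity, but it does not add precision.

```python
import numpy as np

z = -800.0  # hypothetical argument deep on the "off" side of a steep sigmoid

# Naive sigmoid: exp(800) overflows to inf, so the result underflows to 0.0
# (and its log would be -inf).
with np.errstate(over="ignore"):
    naive = 1.0 / (1.0 + np.exp(-z))

# Log space: log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z), which stays finite.
log_sig = -np.logaddexp(0.0, -z)

print(naive)    # 0.0
print(log_sig)  # -800.0
```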

sschoenholz commented 4 years ago

Another option here that could be pretty cool (not that I have time to look into it) would be to try to implement a JAX => mpmath transformation (http://mpmath.org/) for turning calculations into arbitrary-precision calculations. It would probably be a bit slow to do real work in, but good for debugging numerical precision errors.
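The debugging idea can be approximated without any transformation by evaluating the same Python function twice, once with floats and once with a higher-precision number type. The sketch below swaps in the standard library's `fractions.Fraction` (exact rational arithmetic) as a stand-in for mpmath, so it only covers +, -, * and /; transcendental functions would need mpmath itself. The function `f` is a hypothetical example.

```python
from fractions import Fraction


def f(x):
    # Hypothetical test function: mathematically f(x) == 1 for any nonzero x.
    one = 1
    return (one - (one - x)) / x


x_float = 1e-17
x_exact = Fraction(x_float)  # exact rational value of the same float64 input

print(f(x_float))  # 0.0: float64 loses x entirely
print(f(x_exact))  # 1: exact arithmetic recovers the true value
```

Comparing the two results points directly at the cancellation in `1 - x`.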

jakevdp commented 4 years ago

I started working on a potential solution to this problem here: https://github.com/google/jax/pull/3465

jofrevalles commented 1 year ago

@jakevdp Do we have any update on this? It seems that the doubledouble() experimental function was removed in #6530. We are currently using jax.config.update("jax_enable_x64", True) to force JAX to use double-precision floating-point numbers, but we would like to go up to 128 bits instead of 64.

Thanks!
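For reference, the `jax_enable_x64` flag mentioned above is set as shown in this minimal sketch; it is best done at program startup, before any arrays are created. With the flag on, Python floats become float64 arrays, but because float64 is still the widest native float type, 1.0 - x for x = 1e-17 rounds to exactly 1.0 just as in NumPy.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit dtypes (JAX defaults to 32-bit); set this before creating arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.array(1e-17)
print(x.dtype)  # float64

# Still limited by float64: the tiny x is lost in the subtraction.
print(bool((1.0 - x) == 1.0))  # True
```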

jakevdp commented 1 year ago

No, there's no update on this. Given that there's no native support for float128 in XLA, adding it to JAX would be a big project, and I don't anticipate it happening any time soon.