mc2engel opened this issue 4 years ago (status: Open)
I'm curious whether by float128 you mean what NumPy calls float128 (which is really either 80-bit or 64-bit): https://docs.scipy.org/doc/numpy-1.15.0/user/basics.types.html#extended-precision, or whether you mean the IEEE 754 quad-precision type. Do you care about using float128 inside jit?
Thanks for the quick reply. I meant what NumPy calls float128 (on my MacBook Pro I guess this is 80 bits); hopefully that will be sufficient for my purposes. I want to use it potentially inside jit, but certainly inside grad(). Ultimately I want to differentiate a dynamical simulation of the Ising model, composed of sequential spin flips on a 2D lattice. If you have better or alternative ideas for dealing with numbers that are very close to 1, please let me know!
If you're only using grad and only using CPU, then a stop-gap solution would be to use Autograd, which works on top of NumPy and thus might support those dtypes. On CPU it can also be faster than JAX-without-jit (until we lower our dispatch overheads).
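Something like this sketch (shown in plain float64; whether Autograd's tracer handles np.longdouble the same way is untested, and f here is just a toy function for illustration):

```python
# Stop-gap idea: use Autograd (github.com/HIPS/autograd) on plain NumPy.
# This runs as-is in float64; to try extended precision, swap in
# dtype=np.longdouble below; whether Autograd supports that dtype is unverified.
import autograd.numpy as np   # Autograd's wrapper around NumPy
from autograd import grad

def f(x):
    # toy scalar function involving values very close to 1
    return np.sum((1.0 - x) ** 2)

x = np.array([0.1, 0.2, 0.3], dtype=np.float64)
print(grad(f)(x))   # d/dx sum((1-x)^2) = -2*(1-x)
```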
@hawkinsp understands much more than me about how feasible it is to add support in JAX.
Ah, I'm hoping to extend my framework eventually to molecular-dynamics-style simulations, so I'd prefer to stick with JAX (and soon, JAX-MD).
We have a few options here. One would be to plumb float80 or float128 support into XLA. Another would be to implement a precision-doubling transformation (i.e., build a double-double implementation in JAX as a JAX transformation). This might not be something we get to immediately (the latter sounds fun, though), but contributions are welcome.
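To make the double-double option concrete, here is a rough sketch of the core trick (nothing like this exists in JAX today; two_sum and dd_add are illustrative names, and error-free transformations like this assume the compiler does not re-associate floating-point ops):

```python
# Sketch: represent one extended-precision value as an unevaluated sum (hi, lo)
# of two float64s, giving roughly twice the significand precision of a double.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the components must be float64

def two_sum(a, b):
    # Knuth's error-free addition: returns (s, err) with s + err == a + b exactly
    s = a + b
    v = s - a
    err = (a - (s - v)) + (b - v)
    return s, err

def dd_add(x, y):
    # Add two double-double values x = (hi, lo), y = (hi, lo)
    s, e = two_sum(x[0], y[0])
    e = e + x[1] + y[1]
    return two_sum(s, e)

# 1.0 - 1e-20 rounds to exactly 1.0 in plain float64, but the pair keeps
# the small correction in its low word.
one = (jnp.array(1.0), jnp.array(0.0))
tiny = (jnp.array(-1e-20), jnp.array(0.0))
hi, lo = jax.jit(dd_add)(one, tiny)
print(hi, lo)  # 1.0  -1e-20
```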
+1 for extended precision. I've been using JAX a bit in applications with near-step functions and integral approximations where underflow can create issues. An extended-precision JAX transformation that works with jit and grad would be very cool.
Another option here that could be pretty cool (not that I have time to look into it) would be to try to implement a JAX => mpmath transformation (http://mpmath.org/) for turning calculations into arbitrary-precision calculations. It would probably be a bit slow to do real work in, but good for debugging numerical precision errors.
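For a sense of what that would buy you, plain mpmath already makes this kind of check easy (no JAX integration here; that transformation is the missing piece):

```python
# mpmath computes at a user-chosen precision, so it can reveal where a
# float64 computation loses digits. Plain mpmath only, no JAX interop.
from mpmath import mp, mpf

mp.dps = 50              # 50 significant decimal digits
x = mpf('1e-20')
print(mpf(1) - x)        # keeps the tiny correction: 0.9999999999999999999900...
print(1.0 - 1e-20)       # float64 rounds this to exactly 1.0
```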
I started working on a potential solution to this problem here: https://github.com/google/jax/pull/3465
@jakevdp Do we have any update on this? It seems that the doubledouble() experimental function was removed in #6530.
We are currently using jax.config.update("jax_enable_x64", True) to force JAX into using double-precision floating-point numbers, but we would like to go up to 128 bits instead of 64.
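(For context, roughly what our setup looks like; without the flag, JAX defaults to 32-bit floats:)

```python
# Current workaround: enable 64-bit mode, which is as high as JAX goes today.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0).dtype)   # float64; without the flag this would be float32
```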
Thanks!
No, there's no update on this. Given that there's no native support for float128 in XLA, adding it to JAX would be a big project, and I don't anticipate it happening any time soon.
Hello!
I'm using JAX to implement MCMC on a 2D Ising lattice. I need to be able to compute things like 1.0 - x for very small x, <1e-15. Would it be possible for you to implement a float128 data type? It would seriously improve my life!
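For example (illustrative numbers, not from my actual simulation), the kind of cancellation I'm running into looks like this:

```python
# In float64, 1.0 - x collapses to exactly 1.0 once x drops well below 1e-16,
# while NumPy's extended-precision longdouble (80-bit on most x86 platforms)
# still resolves the difference.
import numpy as np

x = 1e-18
print(1.0 - x == 1.0)                                  # True: the correction is lost
print(np.longdouble(1.0) - np.longdouble(x) == 1.0)    # False where longdouble is 80-bit
```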
Thank you, Megan