jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Subnormal `jax.numpy` arrays equal to 0 #16200

Closed: honno closed this issue 1 year ago

honno commented 1 year ago

Description

Subnormal values are seemingly representable in `jax.numpy`, but they erroneously compare equal to 0.

```python
>>> from jax import numpy as jnp
>>> x = jnp.asarray(1.401298464324817e-45, dtype=jnp.float32)
>>> x
Array(1.e-45, dtype=float32)
>>> x == 0
Array(True, dtype=bool)
```
Same with `jnp.float64` (which requires 64-bit mode, e.g. `jax.config.update("jax_enable_x64", True)`):

```python
>>> x = jnp.asarray(5e-324, dtype=jnp.float64)
>>> x
Array(5.e-324, dtype=float64)
>>> x == 0
Array(True, dtype=bool)
```
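For reference, the two inputs above are the smallest positive subnormals of their types (2^-149 for float32, 2^-1074 for float64): the exponent field is all zeros and only the lowest mantissa bit is set. A minimal standard-library sketch to decode the IEEE 754 fields; `float32_fields` and `float64_fields` are illustrative helpers, not JAX APIs.

```python
import struct

def float32_fields(x: float) -> tuple[int, int, int]:
    """Split x, rounded to IEEE 754 binary32, into (sign, exponent, mantissa) fields."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF

def float64_fields(x: float) -> tuple[int, int, int]:
    """Split x, as IEEE 754 binary64, into (sign, exponent, mantissa) fields."""
    bits = struct.unpack("<Q", struct.pack("<d", x))[0]
    return bits >> 63, (bits >> 52) & 0x7FF, bits & ((1 << 52) - 1)

# An exponent field of 0 with a nonzero mantissa marks a subnormal.
print(float32_fields(1.401298464324817e-45))  # (0, 0, 1)
print(float64_fields(5e-324))                 # (0, 0, 1)
print(float32_fields(1.0))                    # (0, 127, 0): a normal number
```

So the values really are stored as nonzero subnormals; the question is only what happens when they are used.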

This might be a won't-fix kind of issue, given that subnormals may well trump NaNs as the floating-point behaviour most likely to make you want to tear your hair out :sweat_smile:

Or if the subnormals aren't actually being represented, I suppose this issue is instead a feature request to change the repr().

What jax/jaxlib version are you using?

jax built from HEAD, jaxlib==0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10, Ubuntu 22.04

patrick-kidger commented 1 year ago

IIRC, it's because it's pretty common to use DAZ (denormals-are-zero) on accelerators like GPUs.
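Under DAZ, the processor treats subnormal *inputs* as zero before an operation runs, so a comparison like `x == 0` effectively evaluates `0 == 0`. The sketch below emulates that in NumPy, which does not turn DAZ on itself; `daz` and `daz_equal` are hypothetical helpers for illustration, not JAX or NumPy APIs.

```python
import numpy as np

def daz(x):
    """Hypothetical helper: emulate denormals-are-zero by reading subnormals as signed zero."""
    x = np.asarray(x)
    tiny = np.finfo(x.dtype).tiny  # smallest positive *normal* value of this dtype
    is_subnormal = (x != 0) & (np.abs(x) < tiny)
    return np.where(is_subnormal, np.copysign(np.zeros_like(x), x), x)

def daz_equal(a, b):
    """Hypothetical helper: evaluate a == b as if DAZ were enabled."""
    return daz(a) == daz(b)

x = np.float32(1.401298464324817e-45)
print(x == 0)                         # False: NumPy compares the subnormal as-is
print(daz_equal(x, np.float32(0.0)))  # True: the subnormal input is read as 0
```

The emulation keeps the sign of the input, matching the signed zero that DAZ hardware substitutes.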

My own workaround.

hawkinsp commented 1 year ago

Yes. This is working as intended: JAX sets the CPU mode that flushes denormals to zero. This is both because some accelerators (e.g., TPU) do not support denormals, and even on CPU it is frequently faster to avoid denormals.
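Flushing matters even when every input is normal, because arithmetic can *produce* subnormals through gradual underflow. A minimal NumPy sketch (NumPy runs in the default IEEE mode, without flush-to-zero):

```python
import numpy as np

# np.finfo(...).tiny is the smallest positive *normal* float32 (2**-126).
smallest_normal = np.finfo(np.float32).tiny

# Dividing it by 4 gives 2**-128, which float32 can only represent as a subnormal.
result = smallest_normal / np.float32(4)

# Without flush-to-zero, gradual underflow keeps the nonzero value.
print(result != 0)               # True
print(result < smallest_normal)  # True: the result is subnormal
```

Assuming JAX's CPU mode also flushes subnormal *outputs* (flush-to-zero), the same division in `jax.numpy` would be expected to yield 0 instead, which avoids the slow subnormal paths on CPUs.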

Denormals are only flushed when you compute with them. If you simply stick some denormals in an array and then print it, we don't flush them; if you do any computation on them, we do.

I'm not sure that having repr also flush would really be that clarifying. Indeed, it might be more confusing in some ways, because there are still ways to observe that the denormals aren't actually zero (e.g., bitcast them to an integer and look at the bits). What do you think?
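The bitcast check can be sketched in NumPy with `ndarray.view`, which reinterprets the stored bytes as another dtype without any floating-point arithmetic, so nothing gets flushed. For a JAX array, `jax.lax.bitcast_convert_type(x, jnp.int32)` is the analogous operation; the NumPy version here is a stand-in, so verify the JAX output on your own install.

```python
import numpy as np

# Store the smallest float32 subnormal next to a true zero; no arithmetic yet.
x = np.array([1.401298464324817e-45, 0.0], dtype=np.float32)

# Reinterpret the same bytes as int32 (a bitcast, not a numeric conversion).
bits = x.view(np.int32)

print(bits.tolist())         # [1, 0]: the subnormal's stored bits are not zero
print((bits != 0).tolist())  # [True, False]
```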

honno commented 1 year ago

> If you simply stick some denormals in an array and then print it, we don't flush the denormals. If you did any computation on them we would.

Ah, thanks for the explanation; this behaviour makes sense now that I think about it.

> because there are still ways to observe that the denormals aren't actually zero (e.g., bitcast them to an integer and look at the bits)

Oh good shout, I'll have a look into this.