honno closed this issue 1 year ago.
IIRC, it's because it's pretty common to use DAZ (denormals-are-zero) on accelerators like GPUs.
Yes. This is working as intended: JAX sets the CPU mode that flushes denormals to zero. This is both because some accelerators (e.g., TPU) do not support denormals, and even on CPU it is frequently faster to avoid denormals.
Denormals are only flushed when you compute with them. If you simply stick some denormals in an array and then print it, we don't flush the denormals. If you did any computation on them we would.
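To illustrate the two modes, here is a hypothetical pure-Python emulation. It is not how JAX or XLA implement this, since they rely on the CPU's own flush mode. DAZ treats subnormal inputs as zero, and FTZ writes subnormal results as zero. `is_subnormal`, `flush`, and `mul_ftz_daz` are made-up names for this sketch. It mirrors the point above: the stored value is untouched, and only the arithmetic path sees zero.

```python
import math
import sys

# Smallest positive *normal* double (about 2.2250738585072014e-308).
# Nonzero values with a smaller magnitude are subnormal (denormal).
SMALLEST_NORMAL = sys.float_info.min


def is_subnormal(x: float) -> bool:
    """True if x is a nonzero double whose magnitude is below the normal range."""
    return x != 0.0 and abs(x) < SMALLEST_NORMAL


def flush(x: float) -> float:
    """Hypothetical helper: replace a subnormal with a zero of the same sign."""
    return math.copysign(0.0, x) if is_subnormal(x) else x


def mul_ftz_daz(a: float, b: float) -> float:
    """Hypothetical emulation of one multiply with DAZ and FTZ enabled:
    subnormal inputs are read as zero (DAZ) and a subnormal result is
    written back as zero (FTZ)."""
    return flush(flush(a) * flush(b))


stored = 5e-324                           # storing the value does not flush it
print(is_subnormal(stored))               # True
print(mul_ftz_daz(stored, 1.0))           # 0.0: DAZ zeroes the input
print(mul_ftz_daz(SMALLEST_NORMAL, 0.5))  # 0.0: FTZ zeroes the subnormal result
```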
I'm not sure that having `repr` also flush would really be that clarifying. Indeed, it might be more confusing in some ways, because there are still ways to observe that the denormals aren't actually zero (e.g., bitcast them to an integer and look at the bits). What do you think?
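Here is a minimal sketch of "look at the bits" in plain Python, using the standard `struct` module rather than JAX. The smallest positive double has only its lowest mantissa bit set, so its bit pattern differs from both `0.0` and `-0.0`. `float64_bits` is a hypothetical helper. On the JAX side, the analogous operation is `jax.lax.bitcast_convert_type` (e.g. converting a float64 array to `jnp.int64`). Presumably it exposes the unflushed bits as described above, though that is untested here.

```python
import struct


def float64_bits(x: float) -> int:
    """Reinterpret an IEEE-754 double's 8 bytes as an unsigned 64-bit integer."""
    (bits,) = struct.unpack("<Q", struct.pack("<d", x))
    return bits


print(float64_bits(0.0))        # 0
print(float64_bits(5e-324))     # 1: only the lowest mantissa bit is set
print(hex(float64_bits(-0.0)))  # 0x8000000000000000: only the sign bit is set
```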
> If you simply stick some denormals in an array and then print it, we don't flush the denormals. If you did any computation on them we would.
Ah, thanks for the explanation. This behaviour makes sense now that I think about it.
> because there are still ways to observe that the denormals aren't actually zero (e.g., bitcast them to an integer and look at the bits)
Oh good shout, I'll have a look into this.
Description
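Subnormal values are seemingly representable by `jax.numpy`, but erroneously equal to 0. The same happens with `jnp.float64`:

```python
>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # float64 arrays need x64 mode
>>> import jax.numpy as jnp
>>> x = jnp.asarray(5e-324, dtype=jnp.float64)
>>> x
Array(5.e-324, dtype=float64)
>>> x == 0
Array(True, dtype=bool)
```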
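For contrast, here is a quick check in plain Python. It assumes a standard CPython process where flush-to-zero is not enabled. `5e-324` is exactly the smallest positive subnormal double, and ordinary IEEE-754 comparison treats it as nonzero. That is the result the report expected from `x == 0`.

```python
import math
import sys

x = 5e-324  # the value from the report above

# math.ulp(0.0) is the smallest positive subnormal double (Python 3.9+).
print(x == math.ulp(0.0))            # True
# It lies below the smallest normal double, so it is subnormal...
print(0.0 < x < sys.float_info.min)  # True
# ...but without flush-to-zero, an ordinary comparison keeps it nonzero.
print(x == 0)                        # False
```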
This might be a won't-fix kind of issue, given that subnormals possibly trump NaNs as the floating-point behaviour most likely to make you want to tear your hair out :sweat_smile:
Or, if the subnormals aren't actually being represented, I suppose this issue is instead a feature request to change the `repr()`.
What jax/jaxlib version are you using?
`jax` built from `HEAD`, `jaxlib==0.4.11`
Which accelerator(s) are you using?
CPU
Additional system info
py10, Ubuntu 22.04