jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add type promotion truncation warning. #9931

Open YouJiacheng opened 2 years ago

YouJiacheng commented 2 years ago
>>> jnp.int64(0)
UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
>>> jnp.dtype(jnp.uint16(0) + jnp.int16(0))
dtype('int32')
>>> jnp.dtype(jnp.uint32(0) + jnp.int32(0))
dtype('int32') # should raise a warning
>>> jnp.uint32(2 ** 32 - 1)
DeviceArray(4294967295, dtype=uint32)
>>> jnp.uint32(2 ** 32 - 1) + jnp.int32(0)
DeviceArray(-1, dtype=int32)

With JAX_ENABLE_X64=True

>>> jnp.uint32(0) + jnp.int32(0)
DeviceArray(0, dtype=int64)

According to the JAX type promotion table, i4 + u4 should produce i8, but without jax_enable_x64 the result is silently truncated to i4.
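To make the request concrete, here is a minimal pure-Python sketch of the kind of warning being asked for. It does not call JAX. The names `INT_TYPES`, `promote`, `promote_with_warning`, and `wrap_int32` are hypothetical helpers invented for illustration, and the promotion rule is a simplified model covering only the integer types listed. It is not how JAX implements promotion.

```python
import warnings

# Hypothetical model of integer dtypes as (is_signed, bits); not a JAX API.
INT_TYPES = {
    "int8": (True, 8), "int16": (True, 16), "int32": (True, 32), "int64": (True, 64),
    "uint8": (False, 8), "uint16": (False, 16), "uint32": (False, 32),
}


def promote(a, b):
    """Promote two integer dtype names (simplified model, e.g. i4 + u4 -> i8)."""
    (signed_a, bits_a), (signed_b, bits_b) = INT_TYPES[a], INT_TYPES[b]
    if signed_a == signed_b:
        signed, bits = signed_a, max(bits_a, bits_b)
    else:
        signed_bits, unsigned_bits = (bits_a, bits_b) if signed_a else (bits_b, bits_a)
        # A signed type must be strictly wider than the unsigned one
        # to represent every value of both inputs.
        signed = True
        bits = signed_bits if signed_bits > unsigned_bits else 2 * unsigned_bits
    return ("int" if signed else "uint") + str(bits)


def promote_with_warning(a, b, x64_enabled=False):
    """Return the result dtype, warning when promotion alone causes truncation."""
    full = promote(a, b)
    if x64_enabled or not full.endswith("64"):
        return full
    truncated = full.replace("64", "32")
    # Explicit 64-bit inputs already trigger the existing warning,
    # so only warn when neither input was 64-bit to begin with.
    if not (a.endswith("64") or b.endswith("64")):
        warnings.warn(
            f"{a} + {b} promotes to {full}, which is truncated to {truncated} "
            "because x64 is disabled.",
            UserWarning,
        )
    return truncated


def wrap_int32(value):
    """Reinterpret an integer as a two's-complement 32-bit signed value."""
    return (value + 2**31) % 2**32 - 2**31


print(promote_with_warning("uint16", "int16"))  # int32, no warning
print(promote_with_warning("uint32", "int32"))  # int32, with a UserWarning
print(wrap_int32(2**32 - 1))                    # -1, matching the output above
```

In this model the warning fires only for mixed-signedness 32-bit inputs, which is the case where a value such as 4294967295 silently becomes -1.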

jakevdp commented 2 years ago

I may look into this, but at first blush I doubt it is feasible to add a useful promotion truncation warning in the presence of x64=False. The X64 flag is a huge hammer and truncates 64-bit types virtually everywhere throughout the package. Warnings from library code would overwhelm any warnings that users have control over. That said, we do have ideas for similar warning flags as we look to remove the X64 flag entirely.

YouJiacheng commented 2 years ago

@jakevdp Hmm. Are there many such promotions inside the package itself?