jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Better error message when setting `jax_enable_x64` using a 0/1 integer (after jax 0.4.26) #20611

Open fehiepsi opened 7 months ago

fehiepsi commented 7 months ago

Description

Currently, calling jax.config.update("jax_enable_x64", 0) raises the following error:

Traceback (most recent call last):
  File "/home/runner/work/numpyro/numpyro/examples/hsgp.py", line 572, in <module>
    numpyro.enable_x64(args.x64)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpyro/util.py", line 48, in enable_x64
    jax.config.update("jax_enable_x64", use_x64)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/config.py", line 88, in update
    self._value_holders[name]._set(val)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/config.py", line 273, in _set
    self._update_global_hook(value)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/config.py", line 1237, in _update_x64_global
    lib.jax_jit.global_state().enable_x64 = val
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self, arg: Optional[bool]) -> None

Invoked with types: jaxlib.xla_extension.jax_jit.JitState, int

System info (python version, jaxlib version, accelerator, etc.)

[System info was attached as an image.]

yashk2810 commented 7 months ago

Do you have a repro?

fehiepsi commented 7 months ago

You can see the error in this CI run (I can also reproduce it locally): https://github.com/pyro-ppl/numpyro/actions/runs/8561672555/job/23463489995?pr=1775

jakevdp commented 7 months ago

Hi, I'm not able to reproduce this with the 0.4.26 release. This is the code I ran:

import jax
print(jax.__version__)
jax.config.update('jax_enable_x64', True)
print(jax.numpy.arange(10).dtype)

Output:

0.4.26
int64

We did see an issue like this in an unreleased version at HEAD, but it was fixed by https://github.com/google/jax/commit/21656115847079981e3915f88ab4533790970f53.

Can you double-check that you're actually using the release, and not an unreleased version installed before this commit?

yashk2810 commented 7 months ago

Sure, but can you give me a minimal repro (preferably JAX-only) that I can run and debug?

jakevdp commented 7 months ago

Ah, I can reproduce this way:

jax.config.update('jax_enable_x64', 1)

This comes from the tighter type validation for configuration flags, added in #19745. jax_enable_x64 requires a boolean argument, and it appears you are passing an integer.
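For illustration, a strict check that rejects integers with a message naming the flag (the kind of message this issue asks for) could look like the minimal sketch below. `check_bool_flag` is a hypothetical helper written for this example, not JAX's actual validation code from #19745. It uses `isinstance(value, bool)` rather than `isinstance(value, int)`: because `bool` is a subclass of `int`, an `int` check would accept True/False and every other integer alike.

```python
# Hypothetical sketch (not JAX's implementation): strict validation for a
# boolean config flag that reports the flag name and the offending type,
# instead of surfacing a low-level binding error.

def check_bool_flag(name: str, value: object) -> bool:
    """Return `value` unchanged if it is a bool; otherwise raise TypeError."""
    # isinstance(1, bool) is False, even though bool subclasses int,
    # so plain integers such as 0 and 1 are rejected here.
    if not isinstance(value, bool):
        raise TypeError(
            f"Configuration option {name!r} expects a bool (True or False), "
            f"but got {value!r} of type {type(value).__name__}. "
            "If you are passing 0/1, wrap it in bool(...)."
        )
    return value


if __name__ == "__main__":
    print(check_bool_flag("jax_enable_x64", True))  # True
    try:
        check_bool_flag("jax_enable_x64", 1)
    except TypeError as err:
        print(err)
```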

Sorry for this breakage. I think this is an example of Hyrum's law in action: I'm not sure we ever intended to support integer arguments to boolean flags, but we inadvertently did.

The best fix would be to use boolean inputs when setting boolean flags, but we could probably loosen this too (cc @hawkinsp).
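On the caller side, the simplest fix is to convert before calling the config API, e.g. jax.config.update("jax_enable_x64", bool(use_x64)). The "loosen this" option could be sketched as a normalizer that accepts real bools and exactly the integers 0 and 1, while still rejecting other values so that a typo such as 2 fails loudly. `normalize_bool_flag` is a hypothetical name used only for this illustration; it is not an existing JAX function.

```python
# Hypothetical sketch of a looser boolean-flag check: accept bools and the
# integers 0/1 (converted to False/True), reject everything else.

def normalize_bool_flag(name: str, value: object) -> bool:
    if isinstance(value, bool):
        return value
    # bools were handled above, so this branch only sees plain ints.
    if isinstance(value, int) and value in (0, 1):
        return bool(value)
    raise TypeError(
        f"Configuration option {name!r} expects a bool or 0/1, got {value!r}."
    )


if __name__ == "__main__":
    print(normalize_bool_flag("jax_enable_x64", 0))     # False
    print(normalize_bool_flag("jax_enable_x64", True))  # True
```

Restricting integers to 0 and 1, rather than applying bool(...) to any value, keeps the convenience that existing callers relied on without silently treating arbitrary truthy values as True.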

fehiepsi commented 7 months ago

Sorry, I thought the new version raised this error for any jax.config.update(...) call. It turns out that, as @jakevdp pointed out (thanks Jake!), I was passing an integer there. Maybe a better error message would be helpful.