fehiepsi opened this issue 7 months ago.
Do you have a repro?
You can see the error here: https://github.com/pyro-ppl/numpyro/actions/runs/8561672555/job/23463489995?pr=1775. I can also reproduce it locally.
Hi, I'm not able to reproduce this with the 0.4.26 release. This is the code I ran:
```python
import jax
print(jax.__version__)
jax.config.update('jax_enable_x64', True)
print(jax.numpy.arange(10).dtype)
```
Output:
```
0.4.26
int64
```
We did see an issue like this in an unreleased version at HEAD, but it was fixed by https://github.com/google/jax/commit/21656115847079981e3915f88ab4533790970f53.
Can you double-check that you're actually using the release, and not an unreleased version installed before this commit?
Sure, but can you give me a minimal repro (preferably JAX only) that I can run and debug?
Ah, I can reproduce this way:
```python
jax.config.update('jax_enable_x64', 1)
```
This comes from the tighter type validation for configuration flags added in #19745. jax_enable_x64 requires a boolean argument, and it appears you are passing an integer.
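The difference between accepting any truthy value and validating the argument's type can be sketched in plain Python. This is a hypothetical illustration, not JAX's configuration code: `lenient_set` and `strict_set` are made-up names, and the sketch only assumes that the stricter validation behaves like an `isinstance(value, bool)` check. It also shows why an integer could have worked "by accident" under a truthiness-based check.

```python
# Hypothetical illustration only: neither function is JAX's real code.


def lenient_set(value: object) -> bool:
    # Truthiness-based acceptance: 1 and 0 silently behave like True/False.
    return bool(value)


def strict_set(value: object) -> bool:
    # Type-checked acceptance: only real booleans pass. bool subclasses int,
    # so isinstance(True, int) is True, but isinstance(1, bool) is False.
    if not isinstance(value, bool):
        raise TypeError(f"expected a bool, got {type(value).__name__}")
    return value


for candidate in (True, 1, 0):
    try:
        result = strict_set(candidate)
        print(f"{candidate!r}: lenient={lenient_set(candidate)}, strict={result}")
    except TypeError as err:
        print(f"{candidate!r}: lenient={lenient_set(candidate)}, strict rejected ({err})")
```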
Sorry for this breakage – I think this is an example of Hyrum's law in action: I'm not sure we ever intended to support integer arguments to boolean flags, but we inadvertently did.
The best fix would be to use boolean inputs when setting boolean flags, but we could probably loosen this too (cc @hawkinsp).
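Following the suggestion to pass real booleans, a caller that gets the setting from a loosely typed source (an environment variable, a CLI argument, an integer default) can normalize it before calling `jax.config.update`. The simplest form of the fix is `jax.config.update("jax_enable_x64", bool(value))`, but `bool("0")` is `True`, so a stricter normalizer may be safer. The helper below, `as_bool_flag`, is hypothetical and not part of JAX or NumPyro; the accepted spellings are an assumption chosen for illustration.

```python
# Hypothetical helper (not part of JAX or NumPyro): normalize loosely typed
# inputs into a real bool before passing them to a boolean config flag.

_TRUE_STRINGS = {"1", "true", "yes", "on"}
_FALSE_STRINGS = {"0", "false", "no", "off"}


def as_bool_flag(value: object) -> bool:
    # Real booleans pass through unchanged (checked first, since bool is an int).
    if isinstance(value, bool):
        return value
    # Integers are accepted only when they are unambiguous: 0 or 1.
    if isinstance(value, int) and value in (0, 1):
        return bool(value)
    # Common string spellings, e.g. from environment variables.
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in _TRUE_STRINGS:
            return True
        if lowered in _FALSE_STRINGS:
            return False
    # Anything else is rejected instead of falling back to truthiness.
    raise ValueError(f"cannot interpret {value!r} as a boolean flag value")
```

With this helper, the failing call from the repro becomes `jax.config.update("jax_enable_x64", as_bool_flag(1))`, which passes `True`. Rejecting values such as `2` or `"maybe"` instead of using truthiness keeps a typo from silently enabling 64-bit mode.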
Sorry, I thought that with the new version, any call to jax.config.update(...) would raise the error. It turns out that, as @jakevdp pointed out (thanks Jake!), I had passed an integer there. Maybe a better error message would help.
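On the error-message point, one option is to special-case the most likely mistake, an integer `0` or `1`, and name the boolean the user probably meant. This is a hypothetical sketch: `check_bool_flag` and its message wording are invented for illustration and do not reflect the message JAX actually raises.

```python
# Hypothetical sketch of a more helpful error for boolean flags; the
# function name and message wording are illustrative, not JAX's.


def check_bool_flag(name: str, value: object) -> None:
    # Real booleans are valid; nothing to report.
    if isinstance(value, bool):
        return
    hint = ""
    # Integers 0 and 1 are the most likely mistake, so suggest the fix.
    if isinstance(value, int) and value in (0, 1):
        hint = f" Did you mean {bool(value)!r}?"
    raise TypeError(
        f"Flag {name!r} expects a bool (True or False), "
        f"got {type(value).__name__} {value!r}.{hint}"
    )
```

For example, `check_bool_flag("jax_enable_x64", 1)` raises `TypeError: Flag 'jax_enable_x64' expects a bool (True or False), got int 1. Did you mean True?`, while other wrong types get the same message without the hint.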
Description
Currently, setting
```python
jax.config.update("jax_enable_x64", 0)
```
will raise an error.
System info (python version, jaxlib version, accelerator, etc.)