Bug: False "UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32" #1841
Python 3.9.19 (main, May 6 2024, 19:43:03)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import tensorflow_probability.substrates.jax as tfp
>>> tfd = tfp.distributions
>>> rng = jax.random.PRNGKey(0)
>>> tfd.OneHotCategorical(logits=jax.random.normal(key=rng,shape=(3,4)), dtype=jax.numpy.float32).sample(seed=rng)
/home/user/anaconda3/envs/ml_exp/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
Array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.]], dtype=float32)
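JAX's astype emits this warning whenever int64 is requested while 64-bit types are disabled, so the message can be reproduced without TFP. The following is a minimal sketch that assumes jax_enable_x64 is left at its default of False:

```python
import warnings

import jax.numpy as jnp

# Default config: jax_enable_x64 is False, so int64 is unavailable and
# astype downcasts to int32 while emitting a UserWarning.
x = jnp.ones((3, 4), dtype=jnp.float32)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    y = x.astype(jnp.int64)          # explicit int64 request -> warning

print(y.dtype)  # int32
for w in caught:
    print(w.category.__name__, str(w.message)[:70])
```

The call in the report never requests int64 itself. That suggests the int64 cast happens somewhere inside the library's sampling path rather than in user code. This is an inference from the warning text and has not been confirmed.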
I explicitly set dtype=jnp.float32 and never requested an int dtype anywhere, yet the warning is still emitted.
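Until the cast is fixed upstream, one possible stopgap is to silence only this specific message around the sampling call. Setting JAX_ENABLE_X64 would also remove the warning, but it changes JAX's default dtypes globally. In the sketch below, ignore_int64_truncation_warning is a hypothetical helper, and its regex assumes the warning text quoted in the report:

```python
import contextlib
import warnings

import jax.numpy as jnp


@contextlib.contextmanager
def ignore_int64_truncation_warning():
    """Hypothetical helper: hide only JAX's int64 -> int32 truncation warning."""
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=r"Explicitly requested dtype .*int64.* is not available",
            category=UserWarning,
        )
        yield  # other warnings still surface normally


# In the report's setting this would wrap the sampling call, e.g.
#     with ignore_int64_truncation_warning():
#         sample = dist.sample(seed=rng)
# Demonstrated here on a plain astype so it runs without TFP:
with ignore_int64_truncation_warning():
    z = jnp.ones(3).astype(jnp.int64)  # no truncation warning is shown
```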