tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Bug: spurious "UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32" #1841

Open eadadi opened 1 week ago

eadadi commented 1 week ago
Python 3.9.19 (main, May  6 2024, 19:43:03)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import tensorflow_probability.substrates.jax as tfp
>>> tfd = tfp.distributions
>>> rng = jax.random.PRNGKey(0)
>>> tfd.OneHotCategorical(logits=jax.random.normal(key=rng,shape=(3,4)), dtype=jax.numpy.float32).sample(seed=rng)
/home/user/anaconda3/envs/ml_exp/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
Array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.]], dtype=float32)

I explicitly passed dtype=jax.numpy.float32 and did not request an integer dtype anywhere, yet I still received the warning.
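Until the source of the warning is tracked down, one possible stopgap is to filter out just this message around the sampling call. Below is a minimal sketch using only the standard library's `warnings` module. `noisy_sample` and `call_quietly` are hypothetical names introduced for illustration: `noisy_sample` is a stand-in for the `.sample(seed=rng)` call above and merely emits the same message, so the sketch runs without JAX installed.

```python
import warnings

# Regex for the start of the warning text shown in the session above.
DTYPE_WARNING = r"Explicitly requested dtype .* requested in astype is not available"


def noisy_sample():
    """Hypothetical stand-in for `.sample(seed=rng)`; emits the same UserWarning."""
    warnings.warn(
        "Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype "
        "is not available, and will be truncated to dtype int32.",
        UserWarning,
    )
    return [[1.0, 0.0, 0.0, 0.0]]


def call_quietly(fn):
    """Run `fn` with only the dtype-truncation warning suppressed."""
    with warnings.catch_warnings():
        # The filter is scoped to this block; other warnings still surface.
        warnings.filterwarnings("ignore", message=DTYPE_WARNING, category=UserWarning)
        return fn()


# Record whatever warnings escape the filter, to check that none do.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sample = call_quietly(noisy_sample)
```

In real code, the callable passed to `call_quietly` would wrap the actual `tfd.OneHotCategorical(...).sample(seed=rng)` call. This only hides the message; it does not change the dtype that is used internally.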

eadadi commented 1 week ago

I will report this bug on the JAX repo as well.