$ py -c "import numpy as np; print(np.bincount(np.ones(3, bool)))"
[0 3]
$ py -c "from jax import numpy as jnp; print(jnp.bincount(jnp.ones(3, bool)))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3068, in bincount
    raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
TypeError: x argument to bincount must have an integer type; got bool
System info (python version, jaxlib version, accelerator, etc.)
Description
numpy.bincount accepts a boolean array, but jax.numpy.bincount raises a TypeError for the same input, as the session above shows.
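
A possible workaround, offered only as a sketch and not as the intended fix: cast the boolean array to an integer dtype before calling jax.numpy.bincount. The variable name mask and the choice of int32 are assumptions made for illustration.

```python
import jax.numpy as jnp

# Boolean input like the one in the report (name `mask` is illustrative).
mask = jnp.ones(3, dtype=bool)

# Cast to an integer dtype first, so False -> 0 and True -> 1.
# int32 is an arbitrary choice here; any integer dtype should do.
counts = jnp.bincount(mask.astype(jnp.int32))
print(counts)
```

With this cast, the counts should match what NumPy returns for the same boolean input.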