jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

bincount rejects bool #24813

Status: Closed (closed by carlosgmartin 1 day ago)

carlosgmartin commented 4 days ago

Description

numpy.bincount accepts boolean input arrays, but jax.numpy.bincount rejects them with a TypeError:

$ py -c "import numpy as np; print(np.bincount(np.ones(3, bool)))"
[0 3]
$ py -c "from jax import numpy as jnp; print(jnp.bincount(jnp.ones(3, bool)))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3068, in bincount
    raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
TypeError: x argument to bincount must have an integer type; got bool
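On JAX versions that reject bool input (such as the 0.4.35 release reported here), one possible workaround is to cast the array to an integer dtype before calling jnp.bincount, since NumPy treats False/True as bins 0 and 1. The sketch below assumes that cast is acceptable for your use case; bincount_bool_compat is a hypothetical helper name, not part of the JAX API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def bincount_bool_compat(x, weights=None, minlength=0, length=None):
    """Hypothetical helper: mimic numpy.bincount's handling of bool input.

    Bool arrays are cast to int32 (False -> 0, True -> 1) and then passed
    to jnp.bincount unchanged; all other dtypes are passed through as-is.
    """
    x = jnp.asarray(x)
    if x.dtype == jnp.bool_:
        x = x.astype(jnp.int32)
    return jnp.bincount(x, weights=weights, minlength=minlength, length=length)


x = jnp.ones(3, dtype=bool)
print(bincount_bool_compat(x))       # same counts as np.bincount(np.ones(3, bool))
print(np.bincount(np.ones(3, bool)))

# Under jax.jit the output shape must be static, so pass `length` explicitly.
jitted = jax.jit(lambda a: bincount_bool_compat(a, length=2))
print(jitted(jnp.array([True, False, True, True])))
```

Casting in a wrapper keeps the call site unchanged and leaves non-bool inputs on the original jnp.bincount path, including its integer-type check.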

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.12.7 (main, Oct  1 2024, 02:05:46) [Clang 15.0.0 (clang-1500.3.9.4)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Carloss-MacBook-Pro-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:46 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6031', machine='arm64')
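The version report above appears to follow the format printed by jax.print_environment_info(). A minimal way to collect the same details for a bug report, assuming a JAX version recent enough to provide that function with its return_string flag:

```python
import jax

# Print jax/jaxlib/numpy/python versions and device info to stdout.
jax.print_environment_info()

# Alternatively, capture the same report as a string (e.g. to paste elsewhere).
report = jax.print_environment_info(return_string=True)
```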