google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.numpy.digitize doesn't work with shape polymorphism #22489

Open tchatow opened 1 month ago

tchatow commented 1 month ago

Description

Tracing jax.numpy.digitize with shape polymorphism raises an error:

  File ".../lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 7548, in searchsorted
    dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
                     ^^^^^^
TypeError: '_DimExpr' object cannot be interpreted as an integer
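The message comes from Python's len() protocol. The result of __len__ must be convertible to an int, and a symbolic dimension is not. The pure-Python sketch below reproduces the same message using hypothetical stand-in classes (FakeDimExpr and FakeTracedArray are not JAX types). It assumes, as the traceback suggests, that len() on a traced array returns its symbolic leading dimension unchanged. Reading shape[0] directly does not attempt an int conversion, so that access alone does not fail.

```python
class FakeDimExpr:
    """Hypothetical stand-in for a symbolic dimension such as jax's _DimExpr."""

    def __repr__(self):
        return "N"


class FakeTracedArray:
    """Hypothetical stand-in for a traced array with a symbolic leading dimension."""

    def __init__(self):
        self.shape = (FakeDimExpr(),)

    def __len__(self):
        # Assumed behavior: len() forwards the leading dimension as-is,
        # and Python then fails to convert it to an int.
        return self.shape[0]


def len_error_message():
    """Call len() on the fake array and return the TypeError message, if any."""
    try:
        len(FakeTracedArray())
    except TypeError as err:
        return str(err)
    return None


print(len_error_message())
# prints: 'FakeDimExpr' object cannot be interpreted as an integer
print(FakeTracedArray().shape[0])  # prints: N (no int conversion is attempted)
```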

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0

mattjj commented 1 month ago

Thanks for reporting this! Any chance you can share a repro?

tchatow commented 1 month ago

Here's a simple repro:

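import jax
import jax.numpy as jnp

# One symbolic dimension, used for the `bins` argument of jnp.digitize(x, bins)
N, = jax.export.symbolic_shape("N")
f = jax.export.export(jax.jit(jnp.digitize))

shape0 = jax.ShapeDtypeStruct((10,), jnp.int32)  # x: static shape
shape1 = jax.ShapeDtypeStruct((N,), jnp.int32)   # bins: symbolic length N
# Raises TypeError: '_DimExpr' object cannot be interpreted as an integer
f(shape0, shape1)

mattjj commented 1 month ago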

This is trickier than I thought. If we fix that local issue (e.g. by replacing len(a) with a.shape[0]), we then hit a downstream issue: the searchsorted implementations rely on static sizes (or at least size bounds).
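To illustrate why a static size matters, here is a simplified NumPy sketch of a branch-free binary search that runs a fixed number of rounds. This is not JAX's actual searchsorted code, and the name searchsorted_fixed_levels is hypothetical. Both the index dtype (mirroring the line in the traceback) and the number of rounds are Python-level decisions derived from len(a), so neither can be computed from an unbounded symbolic length. We assume JAX's implementations face a similar constraint, as described above.

```python
import numpy as np


def searchsorted_fixed_levels(a, v, side="left"):
    """Hypothetical sketch: binary search with a fixed, size-dependent round count."""
    a = np.asarray(a)
    v = np.asarray(v)
    n = len(a)  # must be a concrete int: it fixes both the dtype and the round count
    # Same index-dtype choice as the line shown in the traceback.
    dtype = np.int32 if n <= np.iinfo(np.int32).max else np.int64
    # n.bit_length() == ceil(log2(n + 1)) halvings always close the [low, high] range.
    n_levels = n.bit_length()
    low = np.zeros(v.shape, dtype=dtype)
    high = np.full(v.shape, n, dtype=dtype)
    for _ in range(n_levels):
        active = low < high  # queries whose search range is not yet closed
        mid = (low + high) // 2
        # Clamp so queries that already finished (mid == n) still index safely.
        probe = a[np.minimum(mid, max(n - 1, 0))]
        go_right = (probe < v) if side == "left" else (probe <= v)
        go_right = go_right & active
        low = np.where(go_right, mid + 1, low)
        high = np.where(active & ~go_right, mid, high)
    return low
```

For example, searchsorted_fixed_levels(np.array([1, 3, 5]), np.array([0, 3, 6])) gives [0, 1, 3], matching np.searchsorted. Under tracing, the Python for loop would presumably become a loop primitive whose trip count has to be known, which is where a symbolic N without a bound gets in the way.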

@jakevdp any ideas?

jakevdp commented 1 month ago

Not sure... it's still not clear to me to what extent we should expect shape polymorphism to be supported in JAX APIs. Do we have those goals documented anywhere? Support is pretty incomplete at the moment: if we opened issues like this one for every NumPy API that doesn't support shape polymorphism, we wouldn't have time to do anything else 😀