google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jnp.arange` does not permit dynamic shaped arrays #23423

Open josh146 opened 1 week ago

josh146 commented 1 week ago

Description

I've noticed that, when enabling JAX dynamic shape support via

jax.config.update("jax_dynamic_shapes", True)

jnp.arange (and similarly jnp.linspace) raises an error when passed dynamic variables (which would generate a dynamically shaped array).

I'm wondering whether this is a bug: in the source code, a non-concrete error appears to be raised in both the dynamic_shapes=False and dynamic_shapes=True cases.

https://github.com/google/jax/blob/1e01fa7b0f1355c522f8420569dc778f2633c629/jax/_src/numpy/lax_numpy.py#L3143-L3154

Otherwise, if this is intentional, I can update this bug report to instead be a feature request.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.9 (main, Jan 11 2023, 09:18:18) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MBA-HJXGDYDXH7', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:04 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6020', machine='arm64')

josh146 commented 1 week ago

Note: for reference, my use case does not involve XLA (which does not support dynamically shaped arrays). Instead, I am compiling the generated MHLO via LLVM.

jakevdp commented 1 week ago

Dynamic shapes are still very experimental and don't have much support in JAX APIs. Assigning to @mattjj because he may know whether or not it's expected to work here.

mattjj commented 5 days ago

We haven't prioritized dynamic shapes work for a while, and so the only things available are bits from our past experiments. That said, it is often easy to make specific things work, and so I'm happy to hear specific feature requests like this (e.g. "make my jnp.arange call work with dynamic shapes"). (I'm calling it a feature request rather than a bug because the docs don't say this should work, i.e. this is "intentional" as you say.)

In this case, we only made the jnp.arange function work in its single-argument form:

import jax
import jax.numpy as jnp
jax.config.update('jax_dynamic_shapes', True)

jaxpr = jax.make_jaxpr(jnp.arange)(5)
print(jaxpr)
{ lambda ; a:i32[]. let
    b:i32[a] = iota[dimension=0 dtype=int32 shape=(None,)] a
  in (b,) }

In the code you linked, you can see that we only check that start is concrete if dynamic_shapes is False. But we always check that stop and step are None or concrete.
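To make the asymmetry concrete, here is a toy model of the checks just described, written in plain Python so it runs without JAX. It is only a sketch, not the JAX source: Tracer, ConcretizationError, concrete_or_error, check_arange_args, and is_accepted are hypothetical stand-ins invented for illustration, and the real function does more validation than this.

```python
# Toy model of the argument checks described above. This is NOT JAX source
# code: every name below is a hypothetical stand-in for illustration.

class Tracer:
    """Stand-in for a value that is only known at run time."""

class ConcretizationError(TypeError):
    """Raised when a check needs a concrete value but receives a Tracer."""

def concrete_or_error(value, arg_name):
    # None and ordinary Python numbers count as concrete in this model.
    if isinstance(value, Tracer):
        raise ConcretizationError(f"arange argument '{arg_name}' must be concrete")
    return value

def check_arange_args(start, stop=None, step=None, *, dynamic_shapes):
    # start is checked only when dynamic shapes are disabled.
    if not dynamic_shapes:
        start = concrete_or_error(start, "start")
    # stop and step are checked regardless of the flag.
    stop = concrete_or_error(stop, "stop")
    step = concrete_or_error(step, "step")
    return start, stop, step

def is_accepted(*args, dynamic_shapes):
    # True if the checks pass, False if they raise ConcretizationError.
    try:
        check_arange_args(*args, dynamic_shapes=dynamic_shapes)
        return True
    except ConcretizationError:
        return False

n = Tracer()
print(is_accepted(n, dynamic_shapes=True))     # True: single-argument form
print(is_accepted(0, n, dynamic_shapes=True))  # False: non-concrete stop
print(is_accepted(n, dynamic_shapes=False))    # False: flag disabled
```

In this model, only the single-argument call gets past the checks when the flag is on, which matches the behavior reported in this issue for the two-argument form.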

What signature of jnp.arange did you need?