josh146 opened this issue 1 week ago
Note: for reference, my use case does not involve XLA (which does not support dynamically shaped arrays). Instead, I am compiling the generated MHLO via LLVM.
Dynamic shapes are still very experimental and don't have much support in JAX APIs. Assigning to @mattjj because he may know whether or not it's expected to work here.
We haven't prioritized dynamic shapes work for a while, and so the only things available are bits from our past experiments. That said, it is often easy to make specific things work, and so I'm happy to hear specific feature requests like this (e.g. "make my jnp.arange call work with dynamic shapes"). (I'm calling it a feature request rather than a bug because the docs don't say this should work, i.e. this is "intentional" as you say.)
In this case, we only made the `jnp.arange` function work in its single-argument form:
```python
import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)
jaxpr = jax.make_jaxpr(jnp.arange)(5)
print(jaxpr)
```

```
{ lambda ; a:i32[]. let
    b:i32[a] = iota[dimension=0 dtype=int32 shape=(None,)] a
  in (b,) }
```
In the code you linked, you can see that we only check that `start` is concrete if `dynamic_shapes` is `False`. But we always check that `stop` and `step` are `None` or concrete.
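To make the asymmetry concrete, here is a hypothetical, simplified model of the checks described above. It is not JAX's actual code: the function name `args_required_concrete` and its parameters are illustrative only, and it merely reports which arguments the described rule would require to be concrete.

```python
def args_required_concrete(stop, step, dynamic_shapes):
    """Hypothetical model of the jnp.arange checks described above (not JAX source).

    Returns the names of the arguments that must be concrete at trace time.
    `start` is not a parameter because the described rule never depends on its value.
    """
    required = []
    # `start` is checked for concreteness only when dynamic shapes are disabled.
    if not dynamic_shapes:
        required.append("start")
    # `stop` and `step` must always be None or concrete, regardless of the flag.
    if stop is not None:
        required.append("stop")
    if step is not None:
        required.append("step")
    return required


# Single-argument form, as in jnp.arange(n): nothing needs to be concrete.
print(args_required_concrete(stop=None, step=None, dynamic_shapes=True))  # []
# Two-argument form, as in jnp.arange(0, n): `n` lands in `stop`, so it must be concrete.
print(args_required_concrete(stop="n", step=None, dynamic_shapes=True))  # ['stop']
```

Under this model, a traced `n` only gets through in the single-argument call `jnp.arange(n)`, which matches the behavior reported in this issue.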
What signature of `jnp.arange` did you need?
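Until the multi-argument forms are supported, one possible workaround is to build them from the single-argument form, which traces with `jax_dynamic_shapes` enabled as in the jaxpr printed above. This is a suggestion, not an existing JAX API: the helper names `arange_start_stop` and `linspace_via_arange` are hypothetical, the sketch assumes integer `start <= stop` and `num >= 2` (endpoint included, as in the default `jnp.linspace`), and it has only been checked with concrete inputs, not under dynamic-shape tracing.

```python
import jax.numpy as jnp


def arange_start_stop(start, stop):
    """Hypothetical helper: arange(start, stop) via the single-argument form.

    Assumes integer inputs with start <= stop.
    """
    # jnp.arange(stop - start) yields 0, 1, ..., stop - start - 1; shift it by start.
    return start + jnp.arange(stop - start)


def linspace_via_arange(start, stop, num):
    """Hypothetical helper: linspace(start, stop, num) with the endpoint included.

    Assumes num >= 2; with num == 1 the division below would be by zero.
    """
    # Evenly spaced steps of (stop - start) / (num - 1), starting at start.
    return start + jnp.arange(num) * (stop - start) / (num - 1)


print(arange_start_stop(3, 7).tolist())           # [3, 4, 5, 6]
print(linspace_via_arange(0.0, 1.0, 5).tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

In both helpers the only shape-determining value is the single argument passed to `jnp.arange`, so any dynamic `start`, `stop`, or `num` flows through the form that is already supported.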
Description
I've noticed that, when JAX dynamic shape support is enabled, `jnp.arange` (and similarly, `jnp.linspace`) both raise an error if passed dynamic variables (which would generate a dynamically shaped array). I'm wondering if this is a bug, as I noticed that in the source code, a non-concrete error is raised in both the `dynamic_shape=False` and `dynamic_shape=True` cases? See https://github.com/google/jax/blob/1e01fa7b0f1355c522f8420569dc778f2633c629/jax/_src/numpy/lax_numpy.py#L3143-L3154
Otherwise, if this is intentional, I can update this bug report to instead be a feature request.
System info (python version, jaxlib version, accelerator, etc.)