Open ayaka14732 opened 4 months ago
Better repro (without strided indexing):
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np
# Dense input: a (16, 3) float32 matrix holding 0, 1, ..., 47 in row-major order.
x_shape = (16, 3)
x = jnp.reshape(jnp.arange(np.prod(x_shape), dtype=jnp.float32), x_shape)
# A single integer-array index that selects column 1 five times.
a = jnp.full((5,), 1, dtype=jnp.int32)
# Reference result from ordinary jnp advanced indexing. It has shape (16, 5):
# with one integer-array index, the indexed axis stays in place (axis 1).
y = x[:, a]
# Pallas version of the reference indexing `y = x[:, a]`, run with
# `interpret=True`. `out_shape` is taken from the plain-JAX reference result
# `y`, so the kernel output is expected to match it exactly.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    """Gather the columns of `x_ref` selected by `a` and store them in `o_ref`.

    `o_ref` has the shape of `y`, i.e. (16, 5). Observed behavior: the read
    `x_ref[:, a]` produces a value of shape (5, 16) (the integer-indexed axis
    is moved to the front), so the store fails with
    `ValueError: Invalid shape for swap`.
    """
    # NOTE(review): `a` is a closed-over array constant rather than a kernel
    # input. Some Pallas versions may reject captured constants once the
    # indexing bug is fixed; confirm before relying on this repro.
    o_ref[...] = x_ref[:, a]
# Run the Pallas kernel and compare it elementwise with the plain-JAX
# reference `y`. Currently `kernel(x)` raises
# `ValueError: Invalid shape for swap` instead of returning.
y_ = kernel(x)
np.testing.assert_array_equal(y_, y)
Error:
Traceback (most recent call last):
File "/home/ayx/development/jax/4.py", line 23, in <module>
y_ = kernel(x)
^^^^^^^^^
File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 1129, in wrapped
jaxpr = _trace_kernel_to_jaxpr(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 901, in _trace_kernel_to_jaxpr
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/development/jax/4.py", line 21, in kernel
o_ref[...] = x_ref[:, a]
~~~~~^^^^^
File "/home/ayx/development/jax/jax/_src/numpy/array_methods.py", line 749, in op
return getattr(self.aval, f"_{name}")(self, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/development/jax/jax/_src/state/types.py", line 187, in _setitem
return ref_set(tracer, idx, value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 114, in ref_set
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 110, in ref_swap
return swap_p.bind(ref, value, *flat_indexers, tree=tree)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 178, in _swap_abstract_eval
raise ValueError("Invalid shape for `swap`. "
ValueError: Invalid shape for `swap`. Ref shape: (16, 5). Expected shape: (16, 5). Value shape: (5, 16). Indices: (NDIndexer(indices=(Slice(start=0, size=16, stride=1), Slice(start=0, size=5, stride=1)), shape=(16, 5), int_indexer_shape=(), validate=False),).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Description
I am testing in interpret mode.
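Repro: the original snippet (which used strided indexing) is not included here; see the “Better repro” above.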
Expected behavior:
The line
y_ = kernel(x)
should run successfully and yield the same value as `y`.
Actual behavior: the call raises the `ValueError` ("Invalid shape for swap") shown above.
Explanation:
In the original repro, the resulting array should have shape (4, 5), but Pallas incorrectly assumes its shape to be (5, 4), which causes the error. (In the better repro above, the corresponding shapes are (16, 5) and (5, 16).)
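I have tested various indexing patterns and observed that when there is only one integer-array index, the axis corresponding to it is always moved to the front, which is unnecessary. In NumPy (and `jnp`), a single integer-array index keeps its result dimensions in place; they move to the front only when several integer-array indices are separated by a slice. For example, in the original repro the axis of size 5 is moved to the front, making Pallas assume the shape to be (5, 4) instead of (4, 5).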
This may be related to https://github.com/google/jax/blob/5c9bb612a775ca23d311eef1aeac03dfe0828a62/jax/_src/state/indexing.py#L256-L257.
System info (python version, jaxlib version, accelerator, etc.)